//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map packing information onto
// the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on the iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
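  // For instance, a map result such as `d0 + d1` is a function of d0 without
  // being an AffineDimExpr itself; if d0 were tiled, the tile would have to be
  // threaded through the sum, so such cases are conservatively rejected.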
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. E.g.,
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non-affine dim expressions,
    // e.g. d0 + d1, which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  //   currentPositionTileLoops: dim | pos
  //                             d2  | 0
  //                             d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm` to be used, here
  // it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that propagating the outer dims
    // permutation currently requires that non-affine dim expressions are not
    // permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the indexing_map, under the
/// assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///          inner_dims_pos = [0, 1]
///          inner_tiles = [8, 2]
///          into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the operation below and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
///
///   %pack = tensor.pack %arg0
///             inner_dims_pos = [0]
///             inner_tiles = [8]
///             into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
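  // E.g., for the first input operand in the documentation example above,
  // whose map is (d0, d1) -> (d0) with d0 tiled by 8: the loop below records
  // domainDimToOperandDim = {0 -> 0}, and the tiled-dims loop then appends the
  // point loop d2, yielding exprs = (d0, d2).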
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
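  // For example, if only d0 is tiled and there is no outer permutation, an
  // operand with map (d0, d1) -> (d1) needs no packing and is returned as-is
  // with the extended indexing map.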
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///          inner_dims_pos = [0, 1]
///          inner_tiles = [8, 2]
///          into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///             inner_dims_pos = [0, 1]
///             inner_tiles = [8, 2]
///             into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision need to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // a multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op, as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  //    create a new tensor.empty to avoid breaking dominance, as we are moving
  //    the tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  //    if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate the pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op, some values in the padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// add a zero padding dimension in `high` and `low` for each of the point
/// loops.
class BubbleUpPackThroughPadOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor, so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
        outerDimsPerm);
    auto sourcePack = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have an `outer_dims_perm`, we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal,
        padOp.getNofold());

    // If the pad has more than one user, create an unpack on the new pad to
    // replace the other uses.
    if (!padOp->hasOneUse()) {
      auto unpackEmpty = tensor::UnPackOp::createDestinationTensor(
          rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
      Value unpackedPad = rewriter.create<tensor::UnPackOp>(
          loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm);
      rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
    }

    // Replace the pack with the new pad.
    rewriter.replaceOp(packOp, newPadOp.getResult());

    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit projected dim among [0, 1] is 1; for pos 2, the inner-most
/// non-unit projected dim among [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
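///
/// For example, given dimsPos [0, 2], shape [16, 33, 32], and tileSizes
/// [8, 16], it returns true because 16 % 8 == 0 and 32 % 16 == 0. Any dynamic
/// dim at a position in dimsPos makes it return false.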
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in sequence order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
///   %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<?x16x4xf32> into tensor<?x4xf32>
///   %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///       inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
///   %pack = tensor.pack %in outer_dims_perm = [1, 2]
///       inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///       : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
///   %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for
  // [..., x, 1], packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
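  // For example, with outerDimsPerm [1, 0] and reassocIndices [[0, 1], [2]],
  // the expanded permutation becomes [2, 0, 1].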
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example, given the permutation [1, 0], the reassociations
  // [[0, 1], [2]] become [[0], [1, 2]].
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

/// Project dimsPos to their collapsed positions in the reassocIndices.
///
/// For example, given dimsPos [0, 1, 2, 4] and matching reassocIndices
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]: for pos 0, the reassoc
/// dim [0] is 0; for pos 1 and 2, the reassoc dim in pos [1, 2] is 1; and for
/// pos 4, the reassoc dim [4] is 3.
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;

  // Map each dimension to the position of the corresponding reassociation
  // index.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      // If the dimension is present in the current indices group, the group
      // position within the reassociation map is the desired projected
      // dimension position.
      if (llvm::is_contained(indices, pos)) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

  return projectedPos;
}

/// Bubble up pack op through expand shape op.
///
/// For example:
///
///   %expand = tensor.expand_shape %in [[0], [1, 2]]
///       : tensor<?x64xf32> into tensor<?x4x16xf32>
///   %pack = tensor.pack %expand inner_dims_pos = [2] inner_tiles = [8]
///       into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
///   %pack = tensor.pack %in inner_dims_pos = [1] inner_tiles = [8]
///       into %empty : tensor<?x64xf32> -> tensor<?x8x8xf32>
///   %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
///       : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 tensor::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  // TODO: Handle outer_dims_perm variants.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  // Validate the relations between the dimensions of the shape expansion and
  // the packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed, if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // The expanded dimension is not packed, so it does not affect moving the
    // pack before the shape expansion - simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - in this case operation reordering would affect final
    // element positions and/or shapes can no longer be projected.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension should be packed. Otherwise,
    // the element order will be affected after operation reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion to the new packed shape.
  // The pack.outer_dims_perm is restricted to identity, so the permutation
  // can be omitted for simplicity.
  // TODO: Account for outer dimensions permutation.
  //
  // If reassociation is not possible, then reordering cannot happen.
  // This can be caused by pack padding affecting previously expanded
  // dimensions or packing extending dimensions.
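  // For example (hypothetical shapes): expanding tensor<?x64xf32> into
  // tensor<?x4x16xf32> and packing the last dim with inner tile 5 pads 16 up
  // to 20, giving 4 outer tiles; projecting the pack onto the unexpanded
  // source instead pads 64 up to 65, giving 13 outer tiles, and 13 cannot be
  // reassociated into 4 x 4, so the match fails below.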
  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}

class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only support when the pack op is the only user.
    if (!srcOp || srcOp->getNumResults() != 1 ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
///   %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///       inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///       : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
///   %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///       : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
///   %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///       : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
///   %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///       inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///       : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult pushDownUnPackOpThroughExpandShape(
    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if
  // dim [z] is expanded into [x, y], unpacking on dim z can be projected to
  // unpack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x,
  // 1], unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for
  // the new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example, given the permutation [1, 0], the reassociations
  // [[0, 1], [2]] become [[0], [1, 2]].
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
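  // Continuing the comment example: with two inner tile dims and nextPos == 3,
  // [[0], [1, 2]] becomes [[0], [1, 2], [3], [4]].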
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only supports an unpack op with a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only supports static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
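/// Fails when no operand is produced by a tensor.unpack, or when more than
/// one operand is.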
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///     outs(%1 : tensor<12x56x56x64xf32>) {
///   ^bb0(%out : f32):
///     linalg.yield %out : f32
///   } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///     iterator_types = ["parallel", "parallel", "parallel",
///                       "parallel", "parallel"]}
///     outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///   ^bb0(%out : f32):
///     linalg.yield %out : f32
///   } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // Insert an unPackOp right after the packed generic.
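  // The unpack reuses the dest pack's source, tiles, and permutation, so it is
  // the exact inverse of that pack and restores the original result type.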
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(genericOp.getLoc(), newResult,
                                    destPack.getSource(), innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append as many zero-padding entries to `high` and `low` as there are point
/// loops, so the pad can be applied directly to the still-packed source.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
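    // (Padding along a tiled dimension would require padding inside the tiles
    // themselves, which cannot be expressed by only extending the outer dims.)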
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
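    // A fresh tensor.empty of the original pad result type serves as the
    // unpack destination; the unpack below restores the unpacked layout.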
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
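// Illustrative usage sketch (not part of the pass itself): a pass body would
// typically register these patterns and run the greedy rewriter. `op` stands
// for whatever operation the pass anchors on; the always-true callback
// propagates layouts unconditionally.
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](OpOperand *) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();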