//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map the packing of a
// tensor.pack/tensor.unpack op onto the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on the iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of the tiling data dimensions on the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the domain).
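  // E.g., for the indexing map and outer_dims_perm used in the example inside
  // getPackingInfoFromOperand below, this would be [0, 4, 2, 3, 1].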
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!expr.template isa<AffineDimExpr>())
      return failure();
    int64_t domainDimPos =
        exprs[innerDimPos].template cast<AffineDimExpr>().getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !expr.isa<AffineDimExpr>();
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. E.g.,
  //   outerDimsPerm         = [1, 2, 0]
  //   indexingMap           = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims     = [4, 3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = permutedExpr.template dyn_cast<AffineDimExpr>()) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non-affine dim expressions,
    // e.g. d0 + d1, which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
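    // For now, bail out unless this permutation entry is an identity
    // (index == dim), i.e. the non-dim expression is left in place.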
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 to their positions in the array as:
  //   currentPositionTileLoops: dim | pos
  //                             d2  | 0
  //                             d3  | 1
  // then scan `perm` in order to get the `outer_dims_perm` to be used, which
  // here would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, with
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map are
/// returned. Otherwise, it returns the packed operand and the updated indexing
/// map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold the transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = exprs[i].dyn_cast<AffineDimExpr>()) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(exprs[i].isa<AffineConstantExpr>() &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack an element-wise genericOp and return it.
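/// The input operands are replaced by their packed views (see
/// getOrCreatePackedViewOfOperand), one parallel iterator is appended for each
/// tiled loop, and the original region is cloned into the new generic op.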
static GenericOp packElementWiseOp(RewriterBase &rewriter, GenericOp genericOp,
                                   Value dest, AffineMap packedOutIndexingMap,
                                   const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               ControlPropagationFn controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User-controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision need to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // a multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op, as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  //    create a new tensor.empty to avoid breaking dominance, as we are moving
  //    the tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  //    if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate the pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op, some values in the padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // We'll replace the init operand with the destination of the pack op if the
  // init operand has no users in the body of the linalg.generic (pure
  // elementwise). If it has users, we need to pack the init operand too and
  // replace the init with the packing result.
  Value dest = (genericOp.getRegionOutputArgs()[0].use_empty())
                   ? packOpDest
                   : packedOutOperand;

  return packElementWiseOp(rewriter, genericOp, dest, packedOutIndexingMap,
                           *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
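/// On success, the tensor.pack op is replaced with the results of the newly
/// created packed generic op.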
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should also handle an elementwise op in
// the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through an elementwise generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// the input and output operands. A tensor.unpack op is inserted right after
/// the packed generic. E.g.
///
///   #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
///   %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                            inner_dims_pos = [3] inner_tiles = [32] into %0
///   %2 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       outs(%1 : tensor<12x56x56x64xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
///   #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel",
///                         "parallel", "parallel"]}
///       outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x2x56x56x32xf32>
///   %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                         inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp = packElementWiseOp(rewriter, genericOp, dest,
                                             packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, there is no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type of the generic differs from the destination type of
  // the source unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      genericOp.getDpsInitOperand(0)->get().getType().cast<RankedTensorType>();
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unpack op right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

/// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append a zero entry to `high` and `low` for each of the point loops
/// introduced by the packing.
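///
/// Illustrative sketch (the shapes, tile size, and value names below are
/// invented for this example):
///
///   %1 = tensor.unpack %arg0 inner_dims_pos = [3] inner_tiles = [32]
///       into %0 : tensor<1x16x16x2x32xf32> -> tensor<1x16x16x64xf32>
///   %2 = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] { ... }
///       : tensor<1x16x16x64xf32> to tensor<1x18x18x64xf32>
///
/// becomes a pad of the packed source followed by an unpack, with a zero
/// appended to `low`/`high` for the point loop:
///
///   %1 = tensor.pad %arg0 low[0, 1, 1, 0, 0] high[0, 1, 1, 0, 0] { ... }
///       : tensor<1x16x16x2x32xf32> to tensor<1x18x18x2x32xf32>
///   %2 = tensor.unpack %1 inner_dims_pos = [3] inner_tiles = [32]
///       into %empty : tensor<1x18x18x2x32xf32> -> tensor<1x18x18x64xf32>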
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}