//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map the packing of a pack or
// unpack op onto the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on the iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of the tiling data dimensions on the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling (point) dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the iteration domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
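// For example (an illustrative sketch, mirroring the pack used in the examples
// further below): for a generic with iteration domain (d0, d1) whose result is
// packed with inner_dims_pos = [0, 1] and inner_tiles = [8, 2] (identity
// indexing map, no outer_dims_perm), the populated PackInfo would be:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0 -> 8, 1 -> 2}
//   tileToPointMapping      = {0 -> 2, 1 -> 3}  // point loops d2 and d3
//   outerDimsOnDomainPerm   = []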

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies only to pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. E.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non-affine dim expressions,
    // e.g. d0 + d1, which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 to their positions in the array as:
  //   currentPositionTileLoops: dim | pos
  //                             d2  |  0
  //                             d3  |  1
  // then scan `perm` in order and get the `outer_dims_perm` to be used,
  // here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the updated indexing map, with
/// the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map will be
/// returned. Otherwise, it returns the packed operand and the updated indexing
/// map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, its map only involves d0 and
/// the inner tile size of d0 is 8. Thus, the operation below and the map
/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
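/// The DPS input operands are packed (when affected) with
/// getOrCreatePackedViewOfOperand, one `parallel` iterator type is appended
/// for each inner point loop, the init is replaced by `dest` with
/// `packedOutIndexingMap`, and the original region is cloned into the new op.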
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///     #map0 = affine_map<(d0, d1) -> (d0, d1)>
///     %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///     %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                          iterator_types = ["parallel", "parallel"]}
///         ins(%arg0 : tensor<?x?xf32>)
///         outs(%2 : tensor<?x?xf32>) {
///       ^bb0(%arg3: f32, %arg4: f32):
///         %4 = arith.addf %arg3, %arg3 : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?xf32>
///     %4 = tensor.pack %3
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///     #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///     #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///     #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///     %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///     %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///     %0 = affine.apply #map()[%dim]
///     %1 = affine.apply #map1()[%dim_0]
///     %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///     %pack = tensor.pack %arg0
///       inner_dims_pos = [0, 1]
///       inner_tiles = [8, 2]
///       into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///     %3 = linalg.generic {indexing_maps = [#map2, #map2],
///         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///         ins(%pack : tensor<?x?x8x2xf32>)
///         outs(%arg1 : tensor<?x?x8x2xf32>) {
///       ^bb0(%in: f32, %out: f32):
///         %4 = arith.addf %in, %in : f32
///         linalg.yield %4 : f32
///     } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               ControlPropagationFn controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User-controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // a multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. In this case, we create
  // a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case, we can only move
  // the tensor.pack if its destination dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate the pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // the generic op, some values in the padding area could become NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
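/// On success, the tensor.pack is replaced with the results of the newly
/// packed generic op.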
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on packed domain; pack ops are created for input
/// and output operands. A tensor.unpack op is inserted right after the packed
/// generic. E.g.
///
///     #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
///     %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
///     %0 = tensor.empty() : tensor<12x56x56x64xf32>
///     %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///         inner_dims_pos = [3] inner_tiles = [32] into %0
///     %2 = linalg.generic {indexing_maps = [#map],
///          iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///          outs(%1 : tensor<12x56x56x64xf32>) {
///       ^bb0(%out : f32):
///         linalg.yield %out : f32
///     } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
///     #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
///     %0 = tensor.empty() : tensor<12x56x56x64xf32>
///     %1 = linalg.generic {indexing_maps = [#map],
///          iterator_types = ["parallel", "parallel", "parallel",
///                            "parallel", "parallel"]}
///          outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///       ^bb0(%out : f32):
///         linalg.yield %out : f32
///     } -> tensor<12x2x56x56x32xf32>
///     %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///         inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type of the generic differs from that of the source unpack
  // op, we need to create a new destination tensor. In the dynamic case we
  // always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// append as many zero-padding entries to `high` and `low` as there are point
/// loops, so the pad can be applied directly to the packed source.
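/// For example (an illustrative sketch; the shapes and tile sizes below are
/// made up and not taken from a test):
///
///     %1 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32]
///         into %0 : tensor<4x2x32xf32> -> tensor<4x64xf32>
///     %2 = tensor.pad %1 low[1, 0] high[1, 0] { ... }
///         : tensor<4x64xf32> to tensor<6x64xf32>
///
/// will be converted to
///
///     %pad = tensor.pad %arg0 low[1, 0, 0] high[1, 0, 0] { ... }
///         : tensor<4x2x32xf32> to tensor<6x2x32xf32>
///     %empty = tensor.empty() : tensor<6x64xf32>
///     %2 = tensor.unpack %pad inner_dims_pos = [1] inner_tiles = [32]
///         into %empty : tensor<6x2x32xf32> -> tensor<6x64xf32>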
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have an `outer_dims_perm`, we need to adjust the padded
    // dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
      patterns.getContext(), controlPackUnPackPropagation);
}
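
// Example usage (a sketch only; the surrounding pass boilerplate is assumed
// and not part of this file): a pass can populate these patterns and apply
// them greedily, propagating through every candidate op:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](Operation *op) { return true; });
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));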