//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms operations into an explicit sequence
/// of operations that model the effect of the original operation on the
/// different metadata. This pass uses affine constructs to materialize these
/// effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}
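///
/// For example (illustrative values, not taken from any test), for
///   %sub = memref.subview %m[%o0, 4] [4, 4] [2, 1]
///       : memref<16x16xf32> to memref<4x4xf32, strided<[32, 1], offset: ?>>
/// this computes strides = [2 * 16, 1 * 1] = [32, 1],
/// offset = 0 + %o0 * 16 + 4 * 1, and sizes = [4, 4].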
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  //   newStride#i = baseStride#i * subStride#i
  //   offset = baseOffset + sum(subOffsets#i * baseStrides#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // all of them.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, subviews can drop some dimensions, so some strides and sizes
  // may not end up in the final <base, offset, sizes, strides> value that we
  // are replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type
    // of the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
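///
/// For instance (illustrative, fully static case), with
///   %sv = memref.subview %src[2, 2] [4, 4] [1, 1]
///       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 18>>
/// the metadata of %sv folds to the base buffer of
/// `extract_strided_metadata %src`, offset 18 (2 * 8 + 2 * 1),
/// sizes [4, 4], and strides [8, 1], all materialized as constants.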
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                  for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  //   expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type
///   of the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic returns
/// false for it, and values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
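///
/// For example (illustrative), collapsing a memref<2x?x8xf32> with
/// reassociation [[0, 1, 2]] produces one dynamic collapsed size that is,
/// conceptually, `2 * origSizes#1 * 8`; the affine helpers fold the constant
/// factors so it materializes as a single apply of the form `s0 * 16`.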
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  SmallVector<OpFoldResult> groupStrides;
  ArrayRef<int64_t> srcShape = sourceType.getShape();

  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  return {lastValidStride};
}

/// From `reshape_like(memref, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {
  // Build a plain extract_strided_metadata(memref) from
  // extract_strided_metadata(reassociative_reshape_like(memref)).
  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  // Collect statically known information.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  // Get the special case of 0-D out of the way.
  if (sourceRank == 0) {
    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                           /*sizes=*/ones, /*strides=*/ones};
  }

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(reshapeRank);
  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(reshapeRank);

  // Compute the reshaped strides and sizes from the base strides and sizes.
  SmallVector<OpFoldResult> origSizes =
      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
  SmallVector<OpFoldResult> origStrides =
      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
    SmallVector<OpFoldResult> reshapedSizes =
        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
///             extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like operation in that expression
/// and materialize its effects on the sizes and the strides using
/// affine.apply.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(reshape,
                                         "failed to resolve reshape metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// offset = baseOffset
/// sizes#groupId = product(baseSizes#j, for j in reassociation group groupId)
/// strides#groupId = stride of the innermost dimension of group groupId
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
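///
/// Illustrative example (static shapes): for
///   %c = memref.collapse_shape %m [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
/// the metadata of %c folds to the base buffer and offset of
/// `extract_strided_metadata %m`, the size constant 32 (4 * 8), and the
/// stride constant 1 (the stride of the innermost dimension of the group).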
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op,
          "failed to resolve metadata in terms of source collapse_shape op");
    }

    Location loc = collapseShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = the static sizes of the global
/// strides#i = prod(sizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and identity affine_map for the layout.
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
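///
/// For instance (illustrative), with
///   %r = memref.reinterpret_cast %src to offset: [%o], sizes: [10, 10],
///       strides: [%s, 1]
///       : memref<f32> to memref<10x10xf32, strided<[?, 1], offset: ?>>
/// the metadata of %r is rewritten to reuse %o, the constants 10 and 1, and
/// %s directly, together with the base buffer of
/// `extract_strided_metadata %src`.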
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStride.isDynamic()
///               ? srcStride
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes =
        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
    SmallVector<OpFoldResult> strides =
        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///             memory_space_cast(src) to dstTy)`
/// with
/// ```
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
/// destBaseTy = type(oldBase) with memory space from destTy
/// base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
class ExtractStridedMetadataOpMemorySpaceCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();
    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
    if (!memSpaceCastOp)
      return failure();
    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, memSpaceCastOp.getSource());
    SmallVector<Value> results(newExtractStridedMetadata.getResults());
    // As with most other strided metadata rewrite patterns, don't introduce
    // a use of the base pointer where none existed.
    // This needs to happen here, as opposed to in later dead-code elimination,
    // because these patterns are sometimes used during dialect conversion (see
    // EmulateNarrowType, for example), so adding spurious usages would cause a
    // pre-legalization value to be live that would be dead had this pattern
    // not run.
    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
      auto baseBuffer = results[0];
      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
      MemRefType::Builder newTypeBuilder(baseBufferType);
      newTypeBuilder.setMemorySpace(
          memSpaceCastOp.getResult().getType().getMemorySpace());
      results[0] = rewriter.create<memref::MemorySpaceCastOp>(
          loc, Type{newTypeBuilder}, baseBuffer);
    } else {
      results[0] = nullptr;
    }
    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

/// Replace `base, offset =
///             extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
  return std::make_unique<ExpandStridedMetadataPass>();
}