//===- ExpandStridedMetadata.cpp - Expand strided metadata ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms each such operation into an explicit
/// sequence of operations that models its effect on the different metadata.
/// This pass uses affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}
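///
/// For example (illustrative values only, not taken from any test): for
///   subview(memref<16x8xf32>, offsets = [2, 4], sizes = [8, 2], strides = [1, 1])
/// the base strides are [8, 1] and the base offset is 0, so the new strides
/// are [8 * 1, 1 * 1] = [8, 1], the new offset is 0 + 2 * 8 + 4 * 1 = 20,
/// and the sizes are [8, 2].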
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  // newStride#i = baseStride#i * subStride#i
  // offset = baseOffset + sum(subOffset#i * baseStride#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // them all.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, since subviews can drop some dimensions, some strides and sizes
  // may not end up in the final <base, offset, sizes, strides> value that we
  // are replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type of
    // the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                  for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
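///
/// For example (illustrative values only): expanding the first dimension of
/// memref<?x16xf32> into ?x4 (reassociation group [0, 1]) yields
/// sizes [origSizes#0 floordiv 4, 4]; the static size is read from the result
/// type and the dynamic one is the original size divided by the product of
/// the static sizes in the group.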
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  // expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
///   the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
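///
/// For example (illustrative values only): expanding dimension 0 of
/// memref<?x16xf32> (stride 16) into ?x4 gives the group strides
/// [16 * 4, 16 * 1] = [64, 16]; each expanded stride is the original stride
/// times the product of the expanded sizes to its right within the group.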
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
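///
/// For example (illustrative values only): with indices = {0, 2},
/// maybeConstants = [4, dyn, dyn], and values = [%a, %b, %c], the result is
/// the folded affine product 4 * %c.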
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses into one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
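///
/// For example (illustrative values only): collapsing the group [0, 1] of
/// memref<3x?x8xf32> produces a single size equal to 3 * origSizes#1,
/// typically materialized as an affine apply over the dynamic size.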
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses into one
/// dimension.
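///
/// For example (illustrative values only): for memref<3x4x8xf32> with strides
/// [32, 8, 1] and the reassociation group [0, 1], the collapsed stride is 8,
/// i.e. the stride of the innermost non-unit dimension of the group.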
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  SmallVector<OpFoldResult> groupStrides;
  ArrayRef<int64_t> srcShape = sourceType.getShape();

  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  return {lastValidStride};
}

/// From `reshape_like(memref)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshape_like)
/// strides = getReshapedStrides(reshape_like)
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {
  // Build a plain extract_strided_metadata(memref) from
  // extract_strided_metadata(reassociative_reshape_like(memref)).
  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  // Collect statically known information.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  // Get the special case of 0-D out of the way.
  if (sourceRank == 0) {
    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                           /*sizes=*/ones, /*strides=*/ones};
  }

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(reshapeRank);
  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(reshapeRank);

  // Compute the reshaped strides and sizes from the base strides and sizes.
  SmallVector<OpFoldResult> origSizes =
      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
  SmallVector<OpFoldResult> origStrides =
      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
    SmallVector<OpFoldResult> reshapedSizes =
        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
///              extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like op in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(reshape,
                                         "failed to resolve reshape metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes#i = product(baseSizes#j, for j in reassociation group #i)
/// strides#i = baseStrides#(innermost non-unit dimension of group #i)
/// offset = baseOffset
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op,
          "failed to resolve metadata in terms of source collapse_shape op");
    }

    Location loc = collapseShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
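///
/// For example (illustrative values only): for
/// `%alloc = memref.alloc(%n) : memref<4x?x8xf32>` this yields offset 0,
/// sizes [4, %n, 8], and strides [%n * 8, 8, 1], with the base buffer being
/// %alloc reinterpreted as a flat memref of f32.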
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = globalSizes
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and an identity affine_map for the layout.
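///
/// For example (illustrative values only): for a global of type
/// memref<2x3x4xf32> this yields offset 0, sizes [2, 3, 4], and strides
/// [12, 4, 1] (the suffix products of the sizes).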
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(
///                 reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
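///
/// For example (illustrative values only): for
/// `%dst = memref.reinterpret_cast %src to offset: [%o], sizes: [10, %n],
/// strides: [%s, 1]`, the metadata of %dst is the base buffer of %src,
/// offset %o, sizes [10, %n], and strides [%s, 1].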
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///              extract_strided_metadata(
///                 cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.offset.isDynamic()
///            ? dstTy.offset
///            : extract_strided_metadata(src).offset
/// sizes = for each size#i in dstTy.sizes:
///           !size#i.isDynamic()
///             ? size#i
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each stride#i in dstTy.strides:
///             !stride#i.isDynamic()
///               ? stride#i
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides, or compute them directly from `src`.
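///
/// For example (illustrative values only): for
/// `memref.cast %src : memref<?x8xf32> to memref<4x8xf32>`, the destination
/// type is fully static, so the metadata is simply offset 0, sizes [4, 8],
/// and strides [8, 1], plus the base buffer of %src.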
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///      memory_space_cast(src) to dstTy)`
/// with
/// ```
///    oldBase, offset, sizes, strides = extract_strided_metadata(src)
///    destBaseTy = type(oldBase) with memory space from destTy
///    base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
class ExtractStridedMetadataOpMemorySpaceCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();
    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
    if (!memSpaceCastOp)
      return failure();
    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, memSpaceCastOp.getSource());
    SmallVector<Value> results(newExtractStridedMetadata.getResults());
    // As with most other strided metadata rewrite patterns, don't introduce
    // a use of the base pointer where none existed. This needs to happen here,
    // as opposed to in later dead-code elimination, because these patterns are
    // sometimes used during dialect conversion (see EmulateNarrowType, for
    // example), so adding spurious usages would cause a pre-legalization value
    // to be live that would be dead had this pattern not run.
    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
      auto baseBuffer = results[0];
      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
      MemRefType::Builder newTypeBuilder(baseBufferType);
      newTypeBuilder.setMemorySpace(
          memSpaceCastOp.getResult().getType().getMemorySpace());
      results[0] = rewriter.create<memref::MemorySpaceCastOp>(
          loc, Type{newTypeBuilder}, baseBuffer);
    } else {
      results[0] = nullptr;
    }
    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

/// Replace `base, offset =
///            extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
  return std::make_unique<ExpandStridedMetadataPass>();
}