xref: /llvm-project/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arith/IR/Arith.h"
10 #include "mlir/Dialect/Arith/Utils/Utils.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/Utils/StaticValueUtils.h"
13 #include "mlir/IR/AffineMap.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Interfaces/InferTypeOpInterface.h"
21 #include "mlir/Interfaces/SideEffectInterfaces.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 
26 using namespace mlir;
27 using namespace mlir::memref;
28 
29 /// Materialize a single constant operation from a given attribute value with
30 /// the desired resultant type.
31 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
32                                               Attribute value, Type type,
33                                               Location loc) {
34   return arith::ConstantOp::materialize(builder, value, type, loc);
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Common canonicalization pattern support logic
39 //===----------------------------------------------------------------------===//
40 
41 /// This is a common class used for patterns of the form
42 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
43 /// into the root operation directly.
44 LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
45   bool folded = false;
46   for (OpOperand &operand : op->getOpOperands()) {
47     auto cast = operand.get().getDefiningOp<CastOp>();
48     if (cast && operand.get() != inner &&
49         !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
50       operand.set(cast.getOperand());
51       folded = true;
52     }
53   }
54   return success(folded);
55 }
56 
57 /// Return an unranked/ranked tensor type for the given unranked/ranked memref
58 /// type.
59 Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
60   if (auto memref = llvm::dyn_cast<MemRefType>(type))
61     return RankedTensorType::get(memref.getShape(), memref.getElementType());
62   if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
63     return UnrankedTensorType::get(memref.getElementType());
64   return NoneType::get(type.getContext());
65 }
66 
67 OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
68                                   int64_t dim) {
69   auto memrefType = llvm::cast<MemRefType>(value.getType());
70   SmallVector<OpFoldResult> result;
71   if (memrefType.isDynamicDim(dim))
72     return builder.createOrFold<memref::DimOp>(loc, value, dim);
73 
74   return builder.getIndexAttr(memrefType.getDimSize(dim));
75 }
76 
77 SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78                                                 Location loc, Value value) {
79   auto memrefType = llvm::cast<MemRefType>(value.getType());
80   SmallVector<OpFoldResult> result;
81   for (int64_t i = 0; i < memrefType.getRank(); ++i)
82     result.push_back(getMixedSize(builder, loc, value, i));
83   return result;
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Utility functions for propagating static information
88 //===----------------------------------------------------------------------===//
89 
90 /// Helper function that infers the constant values from a list of \p values,
91 /// a \p memRefTy, and another helper function \p getAttributes.
92 /// The inferred constant values replace the related `OpFoldResult` in
93 /// \p values.
94 ///
95 /// \note This function shouldn't be used directly, instead, use the
96 /// `getConstifiedMixedXXX` methods from the related operations.
97 ///
98 /// \p getAttributes retuns a list of potentially constant values, as determined
99 /// by \p isDynamic, from the given \p memRefTy. The returned list must have as
100 /// many elements as \p values or be empty.
101 ///
102 /// E.g., consider the following example:
103 /// ```
104 /// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
105 ///     memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
106 /// ```
107 /// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
108 /// Now using this helper function with:
109 /// - `values == [2, %dyn_stride]`,
110 /// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
111 /// - `getAttributes == getConstantStrides` (i.e., a wrapper around
112 /// `getStridesAndOffset`), and
113 /// - `isDynamic == ShapedType::isDynamic`
114 /// Will yield: `values == [2, 1]`
115 static void constifyIndexValues(
116     SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
117     MLIRContext *ctxt,
118     llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
119     llvm::function_ref<bool(int64_t)> isDynamic) {
120   SmallVector<int64_t> constValues = getAttributes(memRefTy);
121   Builder builder(ctxt);
122   for (const auto &it : llvm::enumerate(constValues)) {
123     int64_t constValue = it.value();
124     if (!isDynamic(constValue))
125       values[it.index()] = builder.getIndexAttr(constValue);
126   }
127   for (OpFoldResult &ofr : values) {
128     if (auto attr = dyn_cast<Attribute>(ofr)) {
129       // FIXME: We shouldn't need to do that, but right now, the static indices
130       // are created with the wrong type: `i64` instead of `index`.
131       // As a result, if we were to keep the attribute as is, we may fail to see
132       // that two attributes are equal because one would have the i64 type and
133       // the other the index type.
134       // The alternative would be to create constant indices with getI64Attr in
135       // this and the previous loop, but it doesn't logically make sense (we are
136       // dealing with indices here) and would only strenghten the inconsistency
137       // around how static indices are created (some places use getI64Attr,
138       // others use getIndexAttr).
139       // The workaround here is to stick to the IndexAttr type for all the
140       // values, hence we recreate the attribute even when it is already static
141       // to make sure the type is consistent.
142       ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
143       continue;
144     }
145     std::optional<int64_t> maybeConstant =
146         getConstantIntValue(cast<Value>(ofr));
147     if (maybeConstant)
148       ofr = builder.getIndexAttr(*maybeConstant);
149   }
150 }
151 
152 /// Wrapper around `getShape` that conforms to the function signature
153 /// expected for `getAttributes` in `constifyIndexValues`.
154 static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
155   ArrayRef<int64_t> sizes = memRefTy.getShape();
156   return SmallVector<int64_t>(sizes);
157 }
158 
159 /// Wrapper around `getStridesAndOffset` that returns only the offset and
160 /// conforms to the function signature expected for `getAttributes` in
161 /// `constifyIndexValues`.
162 static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
163   SmallVector<int64_t> strides;
164   int64_t offset;
165   LogicalResult hasStaticInformation =
166       memrefType.getStridesAndOffset(strides, offset);
167   if (failed(hasStaticInformation))
168     return SmallVector<int64_t>();
169   return SmallVector<int64_t>(1, offset);
170 }
171 
172 /// Wrapper around `getStridesAndOffset` that returns only the strides and
173 /// conforms to the function signature expected for `getAttributes` in
174 /// `constifyIndexValues`.
175 static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
176   SmallVector<int64_t> strides;
177   int64_t offset;
178   LogicalResult hasStaticInformation =
179       memrefType.getStridesAndOffset(strides, offset);
180   if (failed(hasStaticInformation))
181     return SmallVector<int64_t>();
182   return strides;
183 }
184 
185 //===----------------------------------------------------------------------===//
186 // AllocOp / AllocaOp
187 //===----------------------------------------------------------------------===//
188 
189 void AllocOp::getAsmResultNames(
190     function_ref<void(Value, StringRef)> setNameFn) {
191   setNameFn(getResult(), "alloc");
192 }
193 
194 void AllocaOp::getAsmResultNames(
195     function_ref<void(Value, StringRef)> setNameFn) {
196   setNameFn(getResult(), "alloca");
197 }
198 
199 template <typename AllocLikeOp>
200 static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
201   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
202                 "applies to only alloc or alloca");
203   auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
204   if (!memRefType)
205     return op.emitOpError("result must be a memref");
206 
207   if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
208     return op.emitOpError("dimension operand count does not equal memref "
209                           "dynamic dimension count");
210 
211   unsigned numSymbols = 0;
212   if (!memRefType.getLayout().isIdentity())
213     numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
214   if (op.getSymbolOperands().size() != numSymbols)
215     return op.emitOpError("symbol operand count does not equal memref symbol "
216                           "count: expected ")
217            << numSymbols << ", got " << op.getSymbolOperands().size();
218 
219   return success();
220 }
221 
222 LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
223 
224 LogicalResult AllocaOp::verify() {
225   // An alloca op needs to have an ancestor with an allocation scope trait.
226   if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
227     return emitOpError(
228         "requires an ancestor op with AutomaticAllocationScope trait");
229 
230   return verifyAllocLikeOp(*this);
231 }
232 
233 namespace {
234 /// Fold constant dimensions into an alloc like operation.
235 template <typename AllocLikeOp>
236 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
237   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
238 
239   LogicalResult matchAndRewrite(AllocLikeOp alloc,
240                                 PatternRewriter &rewriter) const override {
241     // Check to see if any dimensions operands are constants.  If so, we can
242     // substitute and drop them.
243     if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
244           APInt constSizeArg;
245           if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
246             return false;
247           return constSizeArg.isNonNegative();
248         }))
249       return failure();
250 
251     auto memrefType = alloc.getType();
252 
253     // Ok, we have one or more constant operands.  Collect the non-constant ones
254     // and keep track of the resultant memref type to build.
255     SmallVector<int64_t, 4> newShapeConstants;
256     newShapeConstants.reserve(memrefType.getRank());
257     SmallVector<Value, 4> dynamicSizes;
258 
259     unsigned dynamicDimPos = 0;
260     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
261       int64_t dimSize = memrefType.getDimSize(dim);
262       // If this is already static dimension, keep it.
263       if (!ShapedType::isDynamic(dimSize)) {
264         newShapeConstants.push_back(dimSize);
265         continue;
266       }
267       auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
268       APInt constSizeArg;
269       if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
270           constSizeArg.isNonNegative()) {
271         // Dynamic shape dimension will be folded.
272         newShapeConstants.push_back(constSizeArg.getZExtValue());
273       } else {
274         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
275         newShapeConstants.push_back(ShapedType::kDynamic);
276         dynamicSizes.push_back(dynamicSize);
277       }
278       dynamicDimPos++;
279     }
280 
281     // Create new memref type (which will have fewer dynamic dimensions).
282     MemRefType newMemRefType =
283         MemRefType::Builder(memrefType).setShape(newShapeConstants);
284     assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
285 
286     // Create and insert the alloc op for the new memref.
287     auto newAlloc = rewriter.create<AllocLikeOp>(
288         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
289         alloc.getAlignmentAttr());
290     // Insert a cast so we have the same type as the old alloc.
291     rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
292     return success();
293   }
294 };
295 
296 /// Fold alloc operations with no users or only store and dealloc uses.
297 template <typename T>
298 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
299   using OpRewritePattern<T>::OpRewritePattern;
300 
301   LogicalResult matchAndRewrite(T alloc,
302                                 PatternRewriter &rewriter) const override {
303     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
304           if (auto storeOp = dyn_cast<StoreOp>(op))
305             return storeOp.getValue() == alloc;
306           return !isa<DeallocOp>(op);
307         }))
308       return failure();
309 
310     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
311       rewriter.eraseOp(user);
312 
313     rewriter.eraseOp(alloc);
314     return success();
315   }
316 };
317 } // namespace
318 
319 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
320                                           MLIRContext *context) {
321   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
322 }
323 
324 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
325                                            MLIRContext *context) {
326   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
327       context);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // ReallocOp
332 //===----------------------------------------------------------------------===//
333 
334 LogicalResult ReallocOp::verify() {
335   auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
336   MemRefType resultType = getType();
337 
338   // The source memref should have identity layout (or none).
339   if (!sourceType.getLayout().isIdentity())
340     return emitError("unsupported layout for source memref type ")
341            << sourceType;
342 
343   // The result memref should have identity layout (or none).
344   if (!resultType.getLayout().isIdentity())
345     return emitError("unsupported layout for result memref type ")
346            << resultType;
347 
348   // The source memref and the result memref should be in the same memory space.
349   if (sourceType.getMemorySpace() != resultType.getMemorySpace())
350     return emitError("different memory spaces specified for source memref "
351                      "type ")
352            << sourceType << " and result memref type " << resultType;
353 
354   // The source memref and the result memref should have the same element type.
355   if (sourceType.getElementType() != resultType.getElementType())
356     return emitError("different element types specified for source memref "
357                      "type ")
358            << sourceType << " and result memref type " << resultType;
359 
360   // Verify that we have the dynamic dimension operand when it is needed.
361   if (resultType.getNumDynamicDims() && !getDynamicResultSize())
362     return emitError("missing dimension operand for result type ")
363            << resultType;
364   if (!resultType.getNumDynamicDims() && getDynamicResultSize())
365     return emitError("unnecessary dimension operand for result type ")
366            << resultType;
367 
368   return success();
369 }
370 
371 void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
372                                             MLIRContext *context) {
373   results.add<SimplifyDeadAlloc<ReallocOp>>(context);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // AllocaScopeOp
378 //===----------------------------------------------------------------------===//
379 
380 void AllocaScopeOp::print(OpAsmPrinter &p) {
381   bool printBlockTerminators = false;
382 
383   p << ' ';
384   if (!getResults().empty()) {
385     p << " -> (" << getResultTypes() << ")";
386     printBlockTerminators = true;
387   }
388   p << ' ';
389   p.printRegion(getBodyRegion(),
390                 /*printEntryBlockArgs=*/false,
391                 /*printBlockTerminators=*/printBlockTerminators);
392   p.printOptionalAttrDict((*this)->getAttrs());
393 }
394 
395 ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
396   // Create a region for the body.
397   result.regions.reserve(1);
398   Region *bodyRegion = result.addRegion();
399 
400   // Parse optional results type list.
401   if (parser.parseOptionalArrowTypeList(result.types))
402     return failure();
403 
404   // Parse the body region.
405   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
406     return failure();
407   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
408                                   result.location);
409 
410   // Parse the optional attribute list.
411   if (parser.parseOptionalAttrDict(result.attributes))
412     return failure();
413 
414   return success();
415 }
416 
417 void AllocaScopeOp::getSuccessorRegions(
418     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
419   if (!point.isParent()) {
420     regions.push_back(RegionSuccessor(getResults()));
421     return;
422   }
423 
424   regions.push_back(RegionSuccessor(&getBodyRegion()));
425 }
426 
427 /// Given an operation, return whether this op is guaranteed to
428 /// allocate an AutomaticAllocationScopeResource
429 static bool isGuaranteedAutomaticAllocation(Operation *op) {
430   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
431   if (!interface)
432     return false;
433   for (auto res : op->getResults()) {
434     if (auto effect =
435             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
436       if (isa<SideEffects::AutomaticAllocationScopeResource>(
437               effect->getResource()))
438         return true;
439     }
440   }
441   return false;
442 }
443 
444 /// Given an operation, return whether this op itself could
445 /// allocate an AutomaticAllocationScopeResource. Note that
446 /// this will not check whether an operation contained within
447 /// the op can allocate.
448 static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
449   // This op itself doesn't create a stack allocation,
450   // the inner allocation should be handled separately.
451   if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
452     return false;
453   MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
454   if (!interface)
455     return true;
456   for (auto res : op->getResults()) {
457     if (auto effect =
458             interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
459       if (isa<SideEffects::AutomaticAllocationScopeResource>(
460               effect->getResource()))
461         return true;
462     }
463   }
464   return false;
465 }
466 
467 /// Return whether this op is the last non terminating op
468 /// in a region. That is to say, it is in a one-block region
469 /// and is only followed by a terminator. This prevents
470 /// extending the lifetime of allocations.
471 static bool lastNonTerminatorInRegion(Operation *op) {
472   return op->getNextNode() == op->getBlock()->getTerminator() &&
473          op->getParentRegion()->getBlocks().size() == 1;
474 }
475 
476 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
477 /// or it contains no allocation.
478 struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
479   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
480 
481   LogicalResult matchAndRewrite(AllocaScopeOp op,
482                                 PatternRewriter &rewriter) const override {
483     bool hasPotentialAlloca =
484         op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
485             if (alloc == op)
486               return WalkResult::advance();
487             if (isOpItselfPotentialAutomaticAllocation(alloc))
488               return WalkResult::interrupt();
489             if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
490               return WalkResult::skip();
491             return WalkResult::advance();
492           }).wasInterrupted();
493 
494     // If this contains no potential allocation, it is always legal to
495     // inline. Otherwise, consider two conditions:
496     if (hasPotentialAlloca) {
497       // If the parent isn't an allocation scope, or we are not the last
498       // non-terminator op in the parent, we will extend the lifetime.
499       if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
500         return failure();
501       if (!lastNonTerminatorInRegion(op))
502         return failure();
503     }
504 
505     Block *block = &op.getRegion().front();
506     Operation *terminator = block->getTerminator();
507     ValueRange results = terminator->getOperands();
508     rewriter.inlineBlockBefore(block, op);
509     rewriter.replaceOp(op, results);
510     rewriter.eraseOp(terminator);
511     return success();
512   }
513 };
514 
515 /// Move allocations into an allocation scope, if it is legal to
516 /// move them (e.g. their operands are available at the location
517 /// the op would be moved to).
518 struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
519   using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
520 
521   LogicalResult matchAndRewrite(AllocaScopeOp op,
522                                 PatternRewriter &rewriter) const override {
523 
524     if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
525       return failure();
526 
527     Operation *lastParentWithoutScope = op->getParentOp();
528 
529     if (!lastParentWithoutScope ||
530         lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
531       return failure();
532 
533     // Only apply to if this is this last non-terminator
534     // op in the block (lest lifetime be extended) of a one
535     // block region
536     if (!lastNonTerminatorInRegion(op) ||
537         !lastNonTerminatorInRegion(lastParentWithoutScope))
538       return failure();
539 
540     while (!lastParentWithoutScope->getParentOp()
541                 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
542       lastParentWithoutScope = lastParentWithoutScope->getParentOp();
543       if (!lastParentWithoutScope ||
544           !lastNonTerminatorInRegion(lastParentWithoutScope))
545         return failure();
546     }
547     assert(lastParentWithoutScope->getParentOp()
548                ->hasTrait<OpTrait::AutomaticAllocationScope>());
549 
550     Region *containingRegion = nullptr;
551     for (auto &r : lastParentWithoutScope->getRegions()) {
552       if (r.isAncestor(op->getParentRegion())) {
553         assert(containingRegion == nullptr &&
554                "only one region can contain the op");
555         containingRegion = &r;
556       }
557     }
558     assert(containingRegion && "op must be contained in a region");
559 
560     SmallVector<Operation *> toHoist;
561     op->walk([&](Operation *alloc) {
562       if (!isGuaranteedAutomaticAllocation(alloc))
563         return WalkResult::skip();
564 
565       // If any operand is not defined before the location of
566       // lastParentWithoutScope (i.e. where we would hoist to), skip.
567       if (llvm::any_of(alloc->getOperands(), [&](Value v) {
568             return containingRegion->isAncestor(v.getParentRegion());
569           }))
570         return WalkResult::skip();
571       toHoist.push_back(alloc);
572       return WalkResult::advance();
573     });
574 
575     if (toHoist.empty())
576       return failure();
577     rewriter.setInsertionPoint(lastParentWithoutScope);
578     for (auto *op : toHoist) {
579       auto *cloned = rewriter.clone(*op);
580       rewriter.replaceOp(op, cloned->getResults());
581     }
582     return success();
583   }
584 };
585 
586 void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
587                                                 MLIRContext *context) {
588   results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // AssumeAlignmentOp
593 //===----------------------------------------------------------------------===//
594 
595 LogicalResult AssumeAlignmentOp::verify() {
596   if (!llvm::isPowerOf2_32(getAlignment()))
597     return emitOpError("alignment must be power of 2");
598   return success();
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // CastOp
603 //===----------------------------------------------------------------------===//
604 
605 void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
606   setNameFn(getResult(), "cast");
607 }
608 
609 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
610 /// source memref. This is useful to fold a memref.cast into a consuming op
611 /// and implement canonicalization patterns for ops in different dialects that
612 /// may consume the results of memref.cast operations. Such foldable memref.cast
613 /// operations are typically inserted as `view` and `subview` ops are
614 /// canonicalized, to preserve the type compatibility of their uses.
615 ///
616 /// Returns true when all conditions are met:
617 /// 1. source and result are ranked memrefs with strided semantics and same
618 /// element type and rank.
619 /// 2. each of the source's size, offset or stride has more static information
620 /// than the corresponding result's size, offset or stride.
621 ///
622 /// Example 1:
623 /// ```mlir
624 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
625 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
626 /// ```
627 ///
628 /// may fold into:
629 ///
630 /// ```mlir
631 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
632 /// ```
633 ///
634 /// Example 2:
635 /// ```
636 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
637 ///          to memref<?x?xf32>
638 ///   consumer %1 : memref<?x?xf32> ...
639 /// ```
640 ///
641 /// may fold into:
642 ///
643 /// ```
644 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
645 /// ```
646 bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
647   MemRefType sourceType =
648       llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
649   MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
650 
651   // Requires ranked MemRefType.
652   if (!sourceType || !resultType)
653     return false;
654 
655   // Requires same elemental type.
656   if (sourceType.getElementType() != resultType.getElementType())
657     return false;
658 
659   // Requires same rank.
660   if (sourceType.getRank() != resultType.getRank())
661     return false;
662 
663   // Only fold casts between strided memref forms.
664   int64_t sourceOffset, resultOffset;
665   SmallVector<int64_t, 4> sourceStrides, resultStrides;
666   if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
667       failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
668     return false;
669 
670   // If cast is towards more static sizes along any dimension, don't fold.
671   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
672     auto ss = std::get<0>(it), st = std::get<1>(it);
673     if (ss != st)
674       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
675         return false;
676   }
677 
678   // If cast is towards more static offset along any dimension, don't fold.
679   if (sourceOffset != resultOffset)
680     if (ShapedType::isDynamic(sourceOffset) &&
681         !ShapedType::isDynamic(resultOffset))
682       return false;
683 
684   // If cast is towards more static strides along any dimension, don't fold.
685   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
686     auto ss = std::get<0>(it), st = std::get<1>(it);
687     if (ss != st)
688       if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
689         return false;
690   }
691 
692   return true;
693 }
694 
695 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
696   if (inputs.size() != 1 || outputs.size() != 1)
697     return false;
698   Type a = inputs.front(), b = outputs.front();
699   auto aT = llvm::dyn_cast<MemRefType>(a);
700   auto bT = llvm::dyn_cast<MemRefType>(b);
701 
702   auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
703   auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
704 
705   if (aT && bT) {
706     if (aT.getElementType() != bT.getElementType())
707       return false;
708     if (aT.getLayout() != bT.getLayout()) {
709       int64_t aOffset, bOffset;
710       SmallVector<int64_t, 4> aStrides, bStrides;
711       if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
712           failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
713           aStrides.size() != bStrides.size())
714         return false;
715 
716       // Strides along a dimension/offset are compatible if the value in the
717       // source memref is static and the value in the target memref is the
718       // same. They are also compatible if either one is dynamic (see
719       // description of MemRefCastOp for details).
720       auto checkCompatible = [](int64_t a, int64_t b) {
721         return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
722       };
723       if (!checkCompatible(aOffset, bOffset))
724         return false;
725       for (const auto &aStride : enumerate(aStrides))
726         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
727           return false;
728     }
729     if (aT.getMemorySpace() != bT.getMemorySpace())
730       return false;
731 
732     // They must have the same rank, and any specified dimensions must match.
733     if (aT.getRank() != bT.getRank())
734       return false;
735 
736     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
737       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
738       if (!ShapedType::isDynamic(aDim) && !ShapedType::isDynamic(bDim) &&
739           aDim != bDim)
740         return false;
741     }
742     return true;
743   } else {
744     if (!aT && !uaT)
745       return false;
746     if (!bT && !ubT)
747       return false;
748     // Unranked to unranked casting is unsupported
749     if (uaT && ubT)
750       return false;
751 
752     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
753     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
754     if (aEltType != bEltType)
755       return false;
756 
757     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
758     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
759     return aMemSpace == bMemSpace;
760   }
761 
762   return false;
763 }
764 
765 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
766   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // CopyOp
771 //===----------------------------------------------------------------------===//
772 
773 namespace {
774 /// If the source/target of a CopyOp is a CastOp that does not modify the shape
775 /// and element type, the cast can be skipped. Such CastOps only cast the layout
776 /// of the type.
777 struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
778   using OpRewritePattern<CopyOp>::OpRewritePattern;
779 
780   LogicalResult matchAndRewrite(CopyOp copyOp,
781                                 PatternRewriter &rewriter) const override {
782     bool modified = false;
783 
784     // Check source.
785     if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
786       auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
787       auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
788 
789       if (fromType && toType) {
790         if (fromType.getShape() == toType.getShape() &&
791             fromType.getElementType() == toType.getElementType()) {
792           rewriter.modifyOpInPlace(copyOp, [&] {
793             copyOp.getSourceMutable().assign(castOp.getSource());
794           });
795           modified = true;
796         }
797       }
798     }
799 
800     // Check target.
801     if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
802       auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
803       auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
804 
805       if (fromType && toType) {
806         if (fromType.getShape() == toType.getShape() &&
807             fromType.getElementType() == toType.getElementType()) {
808           rewriter.modifyOpInPlace(copyOp, [&] {
809             copyOp.getTargetMutable().assign(castOp.getSource());
810           });
811           modified = true;
812         }
813       }
814     }
815 
816     return success(modified);
817   }
818 };
819 
820 /// Fold memref.copy(%x, %x).
821 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
822   using OpRewritePattern<CopyOp>::OpRewritePattern;
823 
824   LogicalResult matchAndRewrite(CopyOp copyOp,
825                                 PatternRewriter &rewriter) const override {
826     if (copyOp.getSource() != copyOp.getTarget())
827       return failure();
828 
829     rewriter.eraseOp(copyOp);
830     return success();
831   }
832 };
833 
834 struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
835   using OpRewritePattern<CopyOp>::OpRewritePattern;
836 
837   static bool isEmptyMemRef(BaseMemRefType type) {
838     return type.hasRank() && llvm::is_contained(type.getShape(), 0);
839   }
840 
841   LogicalResult matchAndRewrite(CopyOp copyOp,
842                                 PatternRewriter &rewriter) const override {
843     if (isEmptyMemRef(copyOp.getSource().getType()) ||
844         isEmptyMemRef(copyOp.getTarget().getType())) {
845       rewriter.eraseOp(copyOp);
846       return success();
847     }
848 
849     return failure();
850   }
851 };
852 } // namespace
853 
854 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
855                                          MLIRContext *context) {
856   results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
857 }
858 
859 LogicalResult CopyOp::fold(FoldAdaptor adaptor,
860                            SmallVectorImpl<OpFoldResult> &results) {
861   /// copy(memrefcast) -> copy
862   bool folded = false;
863   Operation *op = *this;
864   for (OpOperand &operand : op->getOpOperands()) {
865     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
866     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
867       operand.set(castOp.getOperand());
868       folded = true;
869     }
870   }
871   return success(folded);
872 }
873 
874 //===----------------------------------------------------------------------===//
875 // DeallocOp
876 //===----------------------------------------------------------------------===//
877 
878 LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
879                               SmallVectorImpl<OpFoldResult> &results) {
880   /// dealloc(memrefcast) -> dealloc
881   return foldMemRefCast(*this);
882 }
883 
884 //===----------------------------------------------------------------------===//
885 // DimOp
886 //===----------------------------------------------------------------------===//
887 
888 void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
889   setNameFn(getResult(), "dim");
890 }
891 
892 void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
893                   int64_t index) {
894   auto loc = result.location;
895   Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
896   build(builder, result, source, indexValue);
897 }
898 
899 std::optional<int64_t> DimOp::getConstantIndex() {
900   return getConstantIntValue(getIndex());
901 }
902 
903 Speculation::Speculatability DimOp::getSpeculatability() {
904   auto constantIndex = getConstantIndex();
905   if (!constantIndex)
906     return Speculation::NotSpeculatable;
907 
908   auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
909   if (!rankedSourceType)
910     return Speculation::NotSpeculatable;
911 
912   if (rankedSourceType.getRank() <= constantIndex)
913     return Speculation::NotSpeculatable;
914 
915   return Speculation::Speculatable;
916 }
917 
918 /// Return a map with key being elements in `vals` and data being number of
919 /// occurences of it. Use std::map, since the `vals` here are strides and the
920 /// dynamic stride value is the same as the tombstone value for
921 /// `DenseMap<int64_t>`.
922 static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
923   std::map<int64_t, unsigned> numOccurences;
924   for (auto val : vals)
925     numOccurences[val]++;
926   return numOccurences;
927 }
928 
929 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
930 /// to be a subset of `originalType` with some `1` entries erased, return the
931 /// set of indices that specifies which of the entries of `originalShape` are
932 /// dropped to obtain `reducedShape`.
933 /// This accounts for cases where there are multiple unit-dims, but only a
934 /// subset of those are dropped. For MemRefTypes these can be disambiguated
935 /// using the strides. If a dimension is dropped the stride must be dropped too.
936 static FailureOr<llvm::SmallBitVector>
937 computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
938                                ArrayRef<OpFoldResult> sizes) {
939   llvm::SmallBitVector unusedDims(originalType.getRank());
940   if (originalType.getRank() == reducedType.getRank())
941     return unusedDims;
942 
943   for (const auto &dim : llvm::enumerate(sizes))
944     if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
945       if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
946         unusedDims.set(dim.index());
947 
948   // Early exit for the case where the number of unused dims matches the number
949   // of ranks reduced.
950   if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
951       originalType.getRank())
952     return unusedDims;
953 
954   SmallVector<int64_t> originalStrides, candidateStrides;
955   int64_t originalOffset, candidateOffset;
956   if (failed(
957           originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
958       failed(
959           reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
960     return failure();
961 
962   // For memrefs, a dimension is truly dropped if its corresponding stride is
963   // also dropped. This is particularly important when more than one of the dims
964   // is 1. Track the number of occurences of the strides in the original type
965   // and the candidate type. For each unused dim that stride should not be
966   // present in the candidate type. Note that there could be multiple dimensions
967   // that have the same size. We dont need to exactly figure out which dim
968   // corresponds to which stride, we just need to verify that the number of
969   // reptitions of a stride in the original + number of unused dims with that
970   // stride == number of repititions of a stride in the candidate.
971   std::map<int64_t, unsigned> currUnaccountedStrides =
972       getNumOccurences(originalStrides);
973   std::map<int64_t, unsigned> candidateStridesNumOccurences =
974       getNumOccurences(candidateStrides);
975   for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
976     if (!unusedDims.test(dim))
977       continue;
978     int64_t originalStride = originalStrides[dim];
979     if (currUnaccountedStrides[originalStride] >
980         candidateStridesNumOccurences[originalStride]) {
981       // This dim can be treated as dropped.
982       currUnaccountedStrides[originalStride]--;
983       continue;
984     }
985     if (currUnaccountedStrides[originalStride] ==
986         candidateStridesNumOccurences[originalStride]) {
987       // The stride for this is not dropped. Keep as is.
988       unusedDims.reset(dim);
989       continue;
990     }
991     if (currUnaccountedStrides[originalStride] <
992         candidateStridesNumOccurences[originalStride]) {
993       // This should never happen. Cant have a stride in the reduced rank type
994       // that wasnt in the original one.
995       return failure();
996     }
997   }
998 
999   if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1000       originalType.getRank())
1001     return failure();
1002   return unusedDims;
1003 }
1004 
1005 llvm::SmallBitVector SubViewOp::getDroppedDims() {
1006   MemRefType sourceType = getSourceType();
1007   MemRefType resultType = getType();
1008   FailureOr<llvm::SmallBitVector> unusedDims =
1009       computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1010   assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1011   return *unusedDims;
1012 }
1013 
1014 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1015   // All forms of folding require a known index.
1016   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1017   if (!index)
1018     return {};
1019 
1020   // Folding for unranked types (UnrankedMemRefType) is not supported.
1021   auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1022   if (!memrefType)
1023     return {};
1024 
1025   // Out of bound indices produce undefined behavior but are still valid IR.
1026   // Don't choke on them.
1027   int64_t indexVal = index.getInt();
1028   if (indexVal < 0 || indexVal >= memrefType.getRank())
1029     return {};
1030 
1031   // Fold if the shape extent along the given index is known.
1032   if (!memrefType.isDynamicDim(index.getInt())) {
1033     Builder builder(getContext());
1034     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1035   }
1036 
1037   // The size at the given index is now known to be a dynamic size.
1038   unsigned unsignedIndex = index.getValue().getZExtValue();
1039 
1040   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1041   Operation *definingOp = getSource().getDefiningOp();
1042 
1043   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1044     return *(alloc.getDynamicSizes().begin() +
1045              memrefType.getDynamicDimIndex(unsignedIndex));
1046 
1047   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1048     return *(alloca.getDynamicSizes().begin() +
1049              memrefType.getDynamicDimIndex(unsignedIndex));
1050 
1051   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1052     return *(view.getDynamicSizes().begin() +
1053              memrefType.getDynamicDimIndex(unsignedIndex));
1054 
1055   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1056     llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1057     unsigned resultIndex = 0;
1058     unsigned sourceRank = subview.getSourceType().getRank();
1059     unsigned sourceIndex = 0;
1060     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1061       if (unusedDims.test(i))
1062         continue;
1063       if (resultIndex == unsignedIndex) {
1064         sourceIndex = i;
1065         break;
1066       }
1067       resultIndex++;
1068     }
1069     assert(subview.isDynamicSize(sourceIndex) &&
1070            "expected dynamic subview size");
1071     return subview.getDynamicSize(sourceIndex);
1072   }
1073 
1074   if (auto sizeInterface =
1075           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
1076     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
1077            "Expected dynamic subview size");
1078     return sizeInterface.getDynamicSize(unsignedIndex);
1079   }
1080 
1081   // dim(memrefcast) -> dim
1082   if (succeeded(foldMemRefCast(*this)))
1083     return getResult();
1084 
1085   return {};
1086 }
1087 
1088 namespace {
1089 /// Fold dim of a memref reshape operation to a load into the reshape's shape
1090 /// operand.
1091 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1092   using OpRewritePattern<DimOp>::OpRewritePattern;
1093 
1094   LogicalResult matchAndRewrite(DimOp dim,
1095                                 PatternRewriter &rewriter) const override {
1096     auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1097 
1098     if (!reshape)
1099       return rewriter.notifyMatchFailure(
1100           dim, "Dim op is not defined by a reshape op.");
1101 
1102     // dim of a memref reshape can be folded if dim.getIndex() dominates the
1103     // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1104     // cheaply check that either of the following conditions hold:
1105     //      1. dim.getIndex() is defined in the same block as reshape but before
1106     //      reshape.
1107     //      2. dim.getIndex() is defined in a parent block of
1108     //      reshape.
1109 
1110     // Check condition 1
1111     if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1112       if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1113         if (reshape->isBeforeInBlock(definingOp)) {
1114           return rewriter.notifyMatchFailure(
1115               dim,
1116               "dim.getIndex is not defined before reshape in the same block.");
1117         }
1118       } // else dim.getIndex is a block argument to reshape->getBlock and
1119         // dominates reshape
1120     }   // Check condition 2
1121     else if (dim->getBlock() != reshape->getBlock() &&
1122              !dim.getIndex().getParentRegion()->isProperAncestor(
1123                  reshape->getParentRegion())) {
1124       // If dim and reshape are in the same block but dim.getIndex() isn't, we
1125       // already know dim.getIndex() dominates reshape without calling
1126       // `isProperAncestor`
1127       return rewriter.notifyMatchFailure(
1128           dim, "dim.getIndex does not dominate reshape.");
1129     }
1130 
1131     // Place the load directly after the reshape to ensure that the shape memref
1132     // was not mutated.
1133     rewriter.setInsertionPointAfter(reshape);
1134     Location loc = dim.getLoc();
1135     Value load =
1136         rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
1137     if (load.getType() != dim.getType())
1138       load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
1139     rewriter.replaceOp(dim, load);
1140     return success();
1141   }
1142 };
1143 
1144 } // namespace
1145 
1146 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1147                                         MLIRContext *context) {
1148   results.add<DimOfMemRefReshape>(context);
1149 }
1150 
1151 // ---------------------------------------------------------------------------
1152 // DmaStartOp
1153 // ---------------------------------------------------------------------------
1154 
1155 void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1156                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1157                        ValueRange destIndices, Value numElements,
1158                        Value tagMemRef, ValueRange tagIndices, Value stride,
1159                        Value elementsPerStride) {
1160   result.addOperands(srcMemRef);
1161   result.addOperands(srcIndices);
1162   result.addOperands(destMemRef);
1163   result.addOperands(destIndices);
1164   result.addOperands({numElements, tagMemRef});
1165   result.addOperands(tagIndices);
1166   if (stride)
1167     result.addOperands({stride, elementsPerStride});
1168 }
1169 
1170 void DmaStartOp::print(OpAsmPrinter &p) {
1171   p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1172     << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1173     << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1174   if (isStrided())
1175     p << ", " << getStride() << ", " << getNumElementsPerStride();
1176 
1177   p.printOptionalAttrDict((*this)->getAttrs());
1178   p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1179     << ", " << getTagMemRef().getType();
1180 }
1181 
1182 // Parse DmaStartOp.
1183 // Ex:
1184 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1185 //                       %tag[%index], %stride, %num_elt_per_stride :
1186 //                     : memref<3076 x f32, 0>,
1187 //                       memref<1024 x f32, 2>,
1188 //                       memref<1 x i32>
1189 //
1190 ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1191   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1192   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1193   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1194   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1195   OpAsmParser::UnresolvedOperand numElementsInfo;
1196   OpAsmParser::UnresolvedOperand tagMemrefInfo;
1197   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1198   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1199 
1200   SmallVector<Type, 3> types;
1201   auto indexType = parser.getBuilder().getIndexType();
1202 
1203   // Parse and resolve the following list of operands:
1204   // *) source memref followed by its indices (in square brackets).
1205   // *) destination memref followed by its indices (in square brackets).
1206   // *) dma size in KiB.
1207   if (parser.parseOperand(srcMemRefInfo) ||
1208       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1209       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1210       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1211       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1212       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1213       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1214     return failure();
1215 
1216   // Parse optional stride and elements per stride.
1217   if (parser.parseTrailingOperandList(strideInfo))
1218     return failure();
1219 
1220   bool isStrided = strideInfo.size() == 2;
1221   if (!strideInfo.empty() && !isStrided) {
1222     return parser.emitError(parser.getNameLoc(),
1223                             "expected two stride related operands");
1224   }
1225 
1226   if (parser.parseColonTypeList(types))
1227     return failure();
1228   if (types.size() != 3)
1229     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1230 
1231   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1232       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1233       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1234       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1235       // size should be an index.
1236       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1237       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1238       // tag indices should be index.
1239       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1240     return failure();
1241 
1242   if (isStrided) {
1243     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1244       return failure();
1245   }
1246 
1247   return success();
1248 }
1249 
1250 LogicalResult DmaStartOp::verify() {
1251   unsigned numOperands = getNumOperands();
1252 
1253   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1254   // the number of elements.
1255   if (numOperands < 4)
1256     return emitOpError("expected at least 4 operands");
1257 
1258   // Check types of operands. The order of these calls is important: the later
1259   // calls rely on some type properties to compute the operand position.
1260   // 1. Source memref.
1261   if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1262     return emitOpError("expected source to be of memref type");
1263   if (numOperands < getSrcMemRefRank() + 4)
1264     return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1265                          << " operands";
1266   if (!getSrcIndices().empty() &&
1267       !llvm::all_of(getSrcIndices().getTypes(),
1268                     [](Type t) { return t.isIndex(); }))
1269     return emitOpError("expected source indices to be of index type");
1270 
1271   // 2. Destination memref.
1272   if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1273     return emitOpError("expected destination to be of memref type");
1274   unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1275   if (numOperands < numExpectedOperands)
1276     return emitOpError() << "expected at least " << numExpectedOperands
1277                          << " operands";
1278   if (!getDstIndices().empty() &&
1279       !llvm::all_of(getDstIndices().getTypes(),
1280                     [](Type t) { return t.isIndex(); }))
1281     return emitOpError("expected destination indices to be of index type");
1282 
1283   // 3. Number of elements.
1284   if (!getNumElements().getType().isIndex())
1285     return emitOpError("expected num elements to be of index type");
1286 
1287   // 4. Tag memref.
1288   if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1289     return emitOpError("expected tag to be of memref type");
1290   numExpectedOperands += getTagMemRefRank();
1291   if (numOperands < numExpectedOperands)
1292     return emitOpError() << "expected at least " << numExpectedOperands
1293                          << " operands";
1294   if (!getTagIndices().empty() &&
1295       !llvm::all_of(getTagIndices().getTypes(),
1296                     [](Type t) { return t.isIndex(); }))
1297     return emitOpError("expected tag indices to be of index type");
1298 
1299   // Optional stride-related operands must be either both present or both
1300   // absent.
1301   if (numOperands != numExpectedOperands &&
1302       numOperands != numExpectedOperands + 2)
1303     return emitOpError("incorrect number of operands");
1304 
1305   // 5. Strides.
1306   if (isStrided()) {
1307     if (!getStride().getType().isIndex() ||
1308         !getNumElementsPerStride().getType().isIndex())
1309       return emitOpError(
1310           "expected stride and num elements per stride to be of type index");
1311   }
1312 
1313   return success();
1314 }
1315 
1316 LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1317                                SmallVectorImpl<OpFoldResult> &results) {
1318   /// dma_start(memrefcast) -> dma_start
1319   return foldMemRefCast(*this);
1320 }
1321 
1322 // ---------------------------------------------------------------------------
1323 // DmaWaitOp
1324 // ---------------------------------------------------------------------------
1325 
1326 LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1327                               SmallVectorImpl<OpFoldResult> &results) {
1328   /// dma_wait(memrefcast) -> dma_wait
1329   return foldMemRefCast(*this);
1330 }
1331 
1332 LogicalResult DmaWaitOp::verify() {
1333   // Check that the number of tag indices matches the tagMemRef rank.
1334   unsigned numTagIndices = getTagIndices().size();
1335   unsigned tagMemRefRank = getTagMemRefRank();
1336   if (numTagIndices != tagMemRefRank)
1337     return emitOpError() << "expected tagIndices to have the same number of "
1338                             "elements as the tagMemRef rank, expected "
1339                          << tagMemRefRank << ", but got " << numTagIndices;
1340   return success();
1341 }
1342 
1343 //===----------------------------------------------------------------------===//
1344 // ExtractAlignedPointerAsIndexOp
1345 //===----------------------------------------------------------------------===//
1346 
1347 void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1348     function_ref<void(Value, StringRef)> setNameFn) {
1349   setNameFn(getResult(), "intptr");
1350 }
1351 
1352 //===----------------------------------------------------------------------===//
1353 // ExtractStridedMetadataOp
1354 //===----------------------------------------------------------------------===//
1355 
1356 /// The number and type of the results are inferred from the
1357 /// shape of the source.
1358 LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1359     MLIRContext *context, std::optional<Location> location,
1360     ExtractStridedMetadataOp::Adaptor adaptor,
1361     SmallVectorImpl<Type> &inferredReturnTypes) {
1362   auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1363   if (!sourceType)
1364     return failure();
1365 
1366   unsigned sourceRank = sourceType.getRank();
1367   IndexType indexType = IndexType::get(context);
1368   auto memrefType =
1369       MemRefType::get({}, sourceType.getElementType(),
1370                       MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1371   // Base.
1372   inferredReturnTypes.push_back(memrefType);
1373   // Offset.
1374   inferredReturnTypes.push_back(indexType);
1375   // Sizes and strides.
1376   for (unsigned i = 0; i < sourceRank * 2; ++i)
1377     inferredReturnTypes.push_back(indexType);
1378   return success();
1379 }
1380 
1381 void ExtractStridedMetadataOp::getAsmResultNames(
1382     function_ref<void(Value, StringRef)> setNameFn) {
1383   setNameFn(getBaseBuffer(), "base_buffer");
1384   setNameFn(getOffset(), "offset");
1385   // For multi-result to work properly with pretty names and packed syntax `x:3`
1386   // we can only give a pretty name to the first value in the pack.
1387   if (!getSizes().empty()) {
1388     setNameFn(getSizes().front(), "sizes");
1389     setNameFn(getStrides().front(), "strides");
1390   }
1391 }
1392 
1393 /// Helper function to perform the replacement of all constant uses of `values`
1394 /// by a materialized constant extracted from `maybeConstants`.
1395 /// `values` and `maybeConstants` are expected to have the same size.
1396 template <typename Container>
1397 static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1398                                   Container values,
1399                                   ArrayRef<OpFoldResult> maybeConstants) {
1400   assert(values.size() == maybeConstants.size() &&
1401          " expected values and maybeConstants of the same size");
1402   bool atLeastOneReplacement = false;
1403   for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1404     // Don't materialize a constant if there are no uses: this would indice
1405     // infinite loops in the driver.
1406     if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1407       continue;
1408     assert(isa<Attribute>(maybeConstant) &&
1409            "The constified value should be either unchanged (i.e., == result) "
1410            "or a constant");
1411     Value constantVal = rewriter.create<arith::ConstantIndexOp>(
1412         loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1413     for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1414       // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1415       // yet.
1416       op->replaceUsesOfWith(result, constantVal);
1417       atLeastOneReplacement = true;
1418     }
1419   }
1420   return atLeastOneReplacement;
1421 }
1422 
1423 LogicalResult
1424 ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1425                                SmallVectorImpl<OpFoldResult> &results) {
1426   OpBuilder builder(*this);
1427 
1428   bool atLeastOneReplacement = replaceConstantUsesOf(
1429       builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1430       getConstifiedMixedOffset());
1431   atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1432                                                  getConstifiedMixedSizes());
1433   atLeastOneReplacement |= replaceConstantUsesOf(
1434       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1435 
1436   return success(atLeastOneReplacement);
1437 }
1438 
1439 SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1440   SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1441   constifyIndexValues(values, getSource().getType(), getContext(),
1442                       getConstantSizes, ShapedType::isDynamic);
1443   return values;
1444 }
1445 
1446 SmallVector<OpFoldResult>
1447 ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1448   SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1449   constifyIndexValues(values, getSource().getType(), getContext(),
1450                       getConstantStrides, ShapedType::isDynamic);
1451   return values;
1452 }
1453 
1454 OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1455   OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1456   SmallVector<OpFoldResult> values(1, offsetOfr);
1457   constifyIndexValues(values, getSource().getType(), getContext(),
1458                       getConstantOffset, ShapedType::isDynamic);
1459   return values[0];
1460 }
1461 
1462 //===----------------------------------------------------------------------===//
1463 // GenericAtomicRMWOp
1464 //===----------------------------------------------------------------------===//
1465 
1466 void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1467                                Value memref, ValueRange ivs) {
1468   OpBuilder::InsertionGuard g(builder);
1469   result.addOperands(memref);
1470   result.addOperands(ivs);
1471 
1472   if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1473     Type elementType = memrefType.getElementType();
1474     result.addTypes(elementType);
1475 
1476     Region *bodyRegion = result.addRegion();
1477     builder.createBlock(bodyRegion);
1478     bodyRegion->addArgument(elementType, memref.getLoc());
1479   }
1480 }
1481 
1482 LogicalResult GenericAtomicRMWOp::verify() {
1483   auto &body = getRegion();
1484   if (body.getNumArguments() != 1)
1485     return emitOpError("expected single number of entry block arguments");
1486 
1487   if (getResult().getType() != body.getArgument(0).getType())
1488     return emitOpError("expected block argument of the same type result type");
1489 
1490   bool hasSideEffects =
1491       body.walk([&](Operation *nestedOp) {
1492             if (isMemoryEffectFree(nestedOp))
1493               return WalkResult::advance();
1494             nestedOp->emitError(
1495                 "body of 'memref.generic_atomic_rmw' should contain "
1496                 "only operations with no side effects");
1497             return WalkResult::interrupt();
1498           })
1499           .wasInterrupted();
1500   return hasSideEffects ? failure() : success();
1501 }
1502 
1503 ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1504                                       OperationState &result) {
1505   OpAsmParser::UnresolvedOperand memref;
1506   Type memrefType;
1507   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1508 
1509   Type indexType = parser.getBuilder().getIndexType();
1510   if (parser.parseOperand(memref) ||
1511       parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1512       parser.parseColonType(memrefType) ||
1513       parser.resolveOperand(memref, memrefType, result.operands) ||
1514       parser.resolveOperands(ivs, indexType, result.operands))
1515     return failure();
1516 
1517   Region *body = result.addRegion();
1518   if (parser.parseRegion(*body, {}) ||
1519       parser.parseOptionalAttrDict(result.attributes))
1520     return failure();
1521   result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1522   return success();
1523 }
1524 
1525 void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1526   p << ' ' << getMemref() << "[" << getIndices()
1527     << "] : " << getMemref().getType() << ' ';
1528   p.printRegion(getRegion());
1529   p.printOptionalAttrDict((*this)->getAttrs());
1530 }
1531 
1532 //===----------------------------------------------------------------------===//
1533 // AtomicYieldOp
1534 //===----------------------------------------------------------------------===//
1535 
1536 LogicalResult AtomicYieldOp::verify() {
1537   Type parentType = (*this)->getParentOp()->getResultTypes().front();
1538   Type resultType = getResult().getType();
1539   if (parentType != resultType)
1540     return emitOpError() << "types mismatch between yield op: " << resultType
1541                          << " and its parent: " << parentType;
1542   return success();
1543 }
1544 
1545 //===----------------------------------------------------------------------===//
1546 // GlobalOp
1547 //===----------------------------------------------------------------------===//
1548 
1549 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1550                                                    TypeAttr type,
1551                                                    Attribute initialValue) {
1552   p << type;
1553   if (!op.isExternal()) {
1554     p << " = ";
1555     if (op.isUninitialized())
1556       p << "uninitialized";
1557     else
1558       p.printAttributeWithoutType(initialValue);
1559   }
1560 }
1561 
1562 static ParseResult
1563 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1564                                        Attribute &initialValue) {
1565   Type type;
1566   if (parser.parseType(type))
1567     return failure();
1568 
1569   auto memrefType = llvm::dyn_cast<MemRefType>(type);
1570   if (!memrefType || !memrefType.hasStaticShape())
1571     return parser.emitError(parser.getNameLoc())
1572            << "type should be static shaped memref, but got " << type;
1573   typeAttr = TypeAttr::get(type);
1574 
1575   if (parser.parseOptionalEqual())
1576     return success();
1577 
1578   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1579     initialValue = UnitAttr::get(parser.getContext());
1580     return success();
1581   }
1582 
1583   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1584   if (parser.parseAttribute(initialValue, tensorType))
1585     return failure();
1586   if (!llvm::isa<ElementsAttr>(initialValue))
1587     return parser.emitError(parser.getNameLoc())
1588            << "initial value should be a unit or elements attribute";
1589   return success();
1590 }
1591 
1592 LogicalResult GlobalOp::verify() {
1593   auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1594   if (!memrefType || !memrefType.hasStaticShape())
1595     return emitOpError("type should be static shaped memref, but got ")
1596            << getType();
1597 
1598   // Verify that the initial value, if present, is either a unit attribute or
1599   // an elements attribute.
1600   if (getInitialValue().has_value()) {
1601     Attribute initValue = getInitialValue().value();
1602     if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1603       return emitOpError("initial value should be a unit or elements "
1604                          "attribute, but got ")
1605              << initValue;
1606 
1607     // Check that the type of the initial value is compatible with the type of
1608     // the global variable.
1609     if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1610       Type initType = elementsAttr.getType();
1611       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1612       if (initType != tensorType)
1613         return emitOpError("initial value expected to be of type ")
1614                << tensorType << ", but was of type " << initType;
1615     }
1616   }
1617 
1618   if (std::optional<uint64_t> alignAttr = getAlignment()) {
1619     uint64_t alignment = *alignAttr;
1620 
1621     if (!llvm::isPowerOf2_64(alignment))
1622       return emitError() << "alignment attribute value " << alignment
1623                          << " is not a power of 2";
1624   }
1625 
1626   // TODO: verify visibility for declarations.
1627   return success();
1628 }
1629 
1630 ElementsAttr GlobalOp::getConstantInitValue() {
1631   auto initVal = getInitialValue();
1632   if (getConstant() && initVal.has_value())
1633     return llvm::cast<ElementsAttr>(initVal.value());
1634   return {};
1635 }
1636 
1637 //===----------------------------------------------------------------------===//
1638 // GetGlobalOp
1639 //===----------------------------------------------------------------------===//
1640 
1641 LogicalResult
1642 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1643   // Verify that the result type is same as the type of the referenced
1644   // memref.global op.
1645   auto global =
1646       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1647   if (!global)
1648     return emitOpError("'")
1649            << getName() << "' does not reference a valid global memref";
1650 
1651   Type resultType = getResult().getType();
1652   if (global.getType() != resultType)
1653     return emitOpError("result type ")
1654            << resultType << " does not match type " << global.getType()
1655            << " of the global memref @" << getName();
1656   return success();
1657 }
1658 
1659 //===----------------------------------------------------------------------===//
1660 // LoadOp
1661 //===----------------------------------------------------------------------===//
1662 
1663 LogicalResult LoadOp::verify() {
1664   if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1665     return emitOpError("incorrect number of indices for load, expected ")
1666            << getMemRefType().getRank() << " but got " << getIndices().size();
1667   }
1668   return success();
1669 }
1670 
1671 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1672   /// load(memrefcast) -> load
1673   if (succeeded(foldMemRefCast(*this)))
1674     return getResult();
1675   return OpFoldResult();
1676 }
1677 
1678 //===----------------------------------------------------------------------===//
1679 // MemorySpaceCastOp
1680 //===----------------------------------------------------------------------===//
1681 
1682 void MemorySpaceCastOp::getAsmResultNames(
1683     function_ref<void(Value, StringRef)> setNameFn) {
1684   setNameFn(getResult(), "memspacecast");
1685 }
1686 
1687 bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1688   if (inputs.size() != 1 || outputs.size() != 1)
1689     return false;
1690   Type a = inputs.front(), b = outputs.front();
1691   auto aT = llvm::dyn_cast<MemRefType>(a);
1692   auto bT = llvm::dyn_cast<MemRefType>(b);
1693 
1694   auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1695   auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1696 
1697   if (aT && bT) {
1698     if (aT.getElementType() != bT.getElementType())
1699       return false;
1700     if (aT.getLayout() != bT.getLayout())
1701       return false;
1702     if (aT.getShape() != bT.getShape())
1703       return false;
1704     return true;
1705   }
1706   if (uaT && ubT) {
1707     return uaT.getElementType() == ubT.getElementType();
1708   }
1709   return false;
1710 }
1711 
1712 OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1713   // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1714   // t2)
1715   if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1716     getSourceMutable().assign(parentCast.getSource());
1717     return getResult();
1718   }
1719   return Value{};
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // PrefetchOp
1724 //===----------------------------------------------------------------------===//
1725 
1726 void PrefetchOp::print(OpAsmPrinter &p) {
1727   p << " " << getMemref() << '[';
1728   p.printOperands(getIndices());
1729   p << ']' << ", " << (getIsWrite() ? "write" : "read");
1730   p << ", locality<" << getLocalityHint();
1731   p << ">, " << (getIsDataCache() ? "data" : "instr");
1732   p.printOptionalAttrDict(
1733       (*this)->getAttrs(),
1734       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1735   p << " : " << getMemRefType();
1736 }
1737 
1738 ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1739   OpAsmParser::UnresolvedOperand memrefInfo;
1740   SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1741   IntegerAttr localityHint;
1742   MemRefType type;
1743   StringRef readOrWrite, cacheType;
1744 
1745   auto indexTy = parser.getBuilder().getIndexType();
1746   auto i32Type = parser.getBuilder().getIntegerType(32);
1747   if (parser.parseOperand(memrefInfo) ||
1748       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1749       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1750       parser.parseComma() || parser.parseKeyword("locality") ||
1751       parser.parseLess() ||
1752       parser.parseAttribute(localityHint, i32Type, "localityHint",
1753                             result.attributes) ||
1754       parser.parseGreater() || parser.parseComma() ||
1755       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1756       parser.resolveOperand(memrefInfo, type, result.operands) ||
1757       parser.resolveOperands(indexInfo, indexTy, result.operands))
1758     return failure();
1759 
1760   if (readOrWrite != "read" && readOrWrite != "write")
1761     return parser.emitError(parser.getNameLoc(),
1762                             "rw specifier has to be 'read' or 'write'");
1763   result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1764                       parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1765 
1766   if (cacheType != "data" && cacheType != "instr")
1767     return parser.emitError(parser.getNameLoc(),
1768                             "cache type has to be 'data' or 'instr'");
1769 
1770   result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1771                       parser.getBuilder().getBoolAttr(cacheType == "data"));
1772 
1773   return success();
1774 }
1775 
1776 LogicalResult PrefetchOp::verify() {
1777   if (getNumOperands() != 1 + getMemRefType().getRank())
1778     return emitOpError("too few indices");
1779 
1780   return success();
1781 }
1782 
1783 LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1784                                SmallVectorImpl<OpFoldResult> &results) {
1785   // prefetch(memrefcast) -> prefetch
1786   return foldMemRefCast(*this);
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // RankOp
1791 //===----------------------------------------------------------------------===//
1792 
1793 OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1794   // Constant fold rank when the rank of the operand is known.
1795   auto type = getOperand().getType();
1796   auto shapedType = llvm::dyn_cast<ShapedType>(type);
1797   if (shapedType && shapedType.hasRank())
1798     return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1799   return IntegerAttr();
1800 }
1801 
1802 //===----------------------------------------------------------------------===//
1803 // ReinterpretCastOp
1804 //===----------------------------------------------------------------------===//
1805 
1806 void ReinterpretCastOp::getAsmResultNames(
1807     function_ref<void(Value, StringRef)> setNameFn) {
1808   setNameFn(getResult(), "reinterpret_cast");
1809 }
1810 
1811 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1812 /// `staticSizes` and `staticStrides` are automatically filled with
1813 /// source-memref-rank sentinel values that encode dynamic entries.
1814 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1815                               MemRefType resultType, Value source,
1816                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1817                               ArrayRef<OpFoldResult> strides,
1818                               ArrayRef<NamedAttribute> attrs) {
1819   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1820   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1821   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1822   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1823   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1824   result.addAttributes(attrs);
1825   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1826         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1827         b.getDenseI64ArrayAttr(staticSizes),
1828         b.getDenseI64ArrayAttr(staticStrides));
1829 }
1830 
1831 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1832                               Value source, OpFoldResult offset,
1833                               ArrayRef<OpFoldResult> sizes,
1834                               ArrayRef<OpFoldResult> strides,
1835                               ArrayRef<NamedAttribute> attrs) {
1836   auto sourceType = cast<BaseMemRefType>(source.getType());
1837   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1838   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1839   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1840   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1841   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1842   auto stridedLayout = StridedLayoutAttr::get(
1843       b.getContext(), staticOffsets.front(), staticStrides);
1844   auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1845                                     stridedLayout, sourceType.getMemorySpace());
1846   build(b, result, resultType, source, offset, sizes, strides, attrs);
1847 }
1848 
1849 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1850                               MemRefType resultType, Value source,
1851                               int64_t offset, ArrayRef<int64_t> sizes,
1852                               ArrayRef<int64_t> strides,
1853                               ArrayRef<NamedAttribute> attrs) {
1854   SmallVector<OpFoldResult> sizeValues =
1855       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1856         return b.getI64IntegerAttr(v);
1857       }));
1858   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1859       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1860         return b.getI64IntegerAttr(v);
1861       }));
1862   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1863         strideValues, attrs);
1864 }
1865 
1866 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1867                               MemRefType resultType, Value source, Value offset,
1868                               ValueRange sizes, ValueRange strides,
1869                               ArrayRef<NamedAttribute> attrs) {
1870   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1871       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1872   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1873       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1874   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1875 }
1876 
1877 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1878 // completed automatically, like we have for subview and extract_slice.
1879 LogicalResult ReinterpretCastOp::verify() {
1880   // The source and result memrefs should be in the same memory space.
1881   auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1882   auto resultType = llvm::cast<MemRefType>(getType());
1883   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1884     return emitError("different memory spaces specified for source type ")
1885            << srcType << " and result memref type " << resultType;
1886   if (srcType.getElementType() != resultType.getElementType())
1887     return emitError("different element types specified for source type ")
1888            << srcType << " and result memref type " << resultType;
1889 
1890   // Match sizes in result memref type and in static_sizes attribute.
1891   for (auto [idx, resultSize, expectedSize] :
1892        llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1893     if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize)
1894       return emitError("expected result type with size = ")
1895              << (ShapedType::isDynamic(expectedSize)
1896                      ? std::string("dynamic")
1897                      : std::to_string(expectedSize))
1898              << " instead of " << resultSize << " in dim = " << idx;
1899   }
1900 
1901   // Match offset and strides in static_offset and static_strides attributes. If
1902   // result memref type has no affine map specified, this will assume an
1903   // identity layout.
1904   int64_t resultOffset;
1905   SmallVector<int64_t, 4> resultStrides;
1906   if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1907     return emitError("expected result type to have strided layout but found ")
1908            << resultType;
1909 
1910   // Match offset in result memref type and in static_offsets attribute.
1911   int64_t expectedOffset = getStaticOffsets().front();
1912   if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset)
1913     return emitError("expected result type with offset = ")
1914            << (ShapedType::isDynamic(expectedOffset)
1915                    ? std::string("dynamic")
1916                    : std::to_string(expectedOffset))
1917            << " instead of " << resultOffset;
1918 
1919   // Match strides in result memref type and in static_strides attribute.
1920   for (auto [idx, resultStride, expectedStride] :
1921        llvm::enumerate(resultStrides, getStaticStrides())) {
1922     if (!ShapedType::isDynamic(resultStride) && resultStride != expectedStride)
1923       return emitError("expected result type with stride = ")
1924              << (ShapedType::isDynamic(expectedStride)
1925                      ? std::string("dynamic")
1926                      : std::to_string(expectedStride))
1927              << " instead of " << resultStride << " in dim = " << idx;
1928   }
1929 
1930   return success();
1931 }
1932 
1933 OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1934   Value src = getSource();
1935   auto getPrevSrc = [&]() -> Value {
1936     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1937     if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1938       return prev.getSource();
1939 
1940     // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1941     if (auto prev = src.getDefiningOp<CastOp>())
1942       return prev.getSource();
1943 
1944     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1945     // are 0.
1946     if (auto prev = src.getDefiningOp<SubViewOp>())
1947       if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1948             return isConstantIntValue(val, 0);
1949           }))
1950         return prev.getSource();
1951 
1952     return nullptr;
1953   };
1954 
1955   if (auto prevSrc = getPrevSrc()) {
1956     getSourceMutable().assign(prevSrc);
1957     return getResult();
1958   }
1959 
1960   // reinterpret_cast(x) w/o offset/shape/stride changes -> x
1961   if (!ShapedType::isDynamicShape(getType().getShape()) &&
1962       src.getType() == getType() && getStaticOffsets().front() == 0) {
1963     return src;
1964   }
1965 
1966   return nullptr;
1967 }
1968 
1969 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
1970   SmallVector<OpFoldResult> values = getMixedSizes();
1971   constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1972                       ShapedType::isDynamic);
1973   return values;
1974 }
1975 
1976 SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
1977   SmallVector<OpFoldResult> values = getMixedStrides();
1978   constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1979                       ShapedType::isDynamic);
1980   return values;
1981 }
1982 
1983 OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
1984   SmallVector<OpFoldResult> values = getMixedOffsets();
1985   assert(values.size() == 1 &&
1986          "reinterpret_cast must have one and only one offset");
1987   constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1988                       ShapedType::isDynamic);
1989   return values[0];
1990 }
1991 
1992 namespace {
1993 /// Replace the sequence:
1994 /// ```
1995 /// base, offset, sizes, strides = extract_strided_metadata src
1996 /// dst = reinterpret_cast base to offset, sizes, strides
1997 /// ```
1998 /// With
1999 ///
2000 /// ```
2001 /// dst = memref.cast src
2002 /// ```
2003 ///
2004 /// Note: The cast operation is only inserted when the type of dst and src
2005 /// are not the same. E.g., when going from <4xf32> to <?xf32>.
2006 ///
2007 /// This pattern also matches when the offset, sizes, and strides don't come
2008 /// directly from the `extract_strided_metadata`'s results but it can be
2009 /// statically proven that they would hold the same values.
2010 ///
2011 /// For instance, the following sequence would be replaced:
2012 /// ```
2013 /// base, offset, sizes, strides =
2014 ///   extract_strided_metadata memref : memref<3x4xty>
2015 /// dst = reinterpret_cast base to 0, [3, 4], strides
2016 /// ```
2017 /// Because we know (thanks to the type of the input memref) that variable
2018 /// `offset` and `sizes` will respectively hold 0 and [3, 4].
2019 ///
2020 /// Similarly, the following sequence would be replaced:
2021 /// ```
2022 /// c0 = arith.constant 0
2023 /// c4 = arith.constant 4
2024 /// base, offset, sizes, strides =
2025 ///   extract_strided_metadata memref : memref<3x4xty>
2026 /// dst = reinterpret_cast base to c0, [3, c4], strides
2027 /// ```
2028 /// Because we know that `offset`and `c0` will hold 0
2029 /// and `c4` will hold 4.
2030 struct ReinterpretCastOpExtractStridedMetadataFolder
2031     : public OpRewritePattern<ReinterpretCastOp> {
2032 public:
2033   using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2034 
2035   LogicalResult matchAndRewrite(ReinterpretCastOp op,
2036                                 PatternRewriter &rewriter) const override {
2037     auto extractStridedMetadata =
2038         op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2039     if (!extractStridedMetadata)
2040       return failure();
2041     // Check if the reinterpret cast reconstructs a memref with the exact same
2042     // properties as the extract strided metadata.
2043 
2044     // First, check that the strides are the same.
2045     SmallVector<OpFoldResult> extractStridesOfr =
2046         extractStridedMetadata.getConstifiedMixedStrides();
2047     SmallVector<OpFoldResult> reinterpretStridesOfr =
2048         op.getConstifiedMixedStrides();
2049     if (extractStridesOfr.size() != reinterpretStridesOfr.size())
2050       return failure();
2051 
2052     unsigned rank = op.getType().getRank();
2053     for (unsigned i = 0; i < rank; ++i) {
2054       if (extractStridesOfr[i] != reinterpretStridesOfr[i])
2055         return failure();
2056     }
2057 
2058     // Second, check the sizes.
2059     assert(extractStridedMetadata.getSizes().size() ==
2060                op.getMixedSizes().size() &&
2061            "Strides and sizes rank must match");
2062     SmallVector<OpFoldResult> extractSizesOfr =
2063         extractStridedMetadata.getConstifiedMixedSizes();
2064     SmallVector<OpFoldResult> reinterpretSizesOfr =
2065         op.getConstifiedMixedSizes();
2066     for (unsigned i = 0; i < rank; ++i) {
2067       if (extractSizesOfr[i] != reinterpretSizesOfr[i])
2068         return failure();
2069     }
2070     // Finally, check the offset.
2071     assert(op.getMixedOffsets().size() == 1 &&
2072            "reinterpret_cast with more than one offset should have been "
2073            "rejected by the verifier");
2074     OpFoldResult extractOffsetOfr =
2075         extractStridedMetadata.getConstifiedMixedOffset();
2076     OpFoldResult reinterpretOffsetOfr = op.getConstifiedMixedOffset();
2077     if (extractOffsetOfr != reinterpretOffsetOfr)
2078       return failure();
2079 
2080     // At this point, we know that the back and forth between extract strided
2081     // metadata and reinterpret cast is a noop. However, the final type of the
2082     // reinterpret cast may not be exactly the same as the original memref.
2083     // E.g., it could be changing a dimension from static to dynamic. Check that
2084     // here and add a cast if necessary.
2085     Type srcTy = extractStridedMetadata.getSource().getType();
2086     if (srcTy == op.getResult().getType())
2087       rewriter.replaceOp(op, extractStridedMetadata.getSource());
2088     else
2089       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2090                                           extractStridedMetadata.getSource());
2091 
2092     return success();
2093   }
2094 };
2095 } // namespace
2096 
2097 void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2098                                                     MLIRContext *context) {
2099   results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2100 }
2101 
2102 //===----------------------------------------------------------------------===//
2103 // Reassociative reshape ops
2104 //===----------------------------------------------------------------------===//
2105 
2106 void CollapseShapeOp::getAsmResultNames(
2107     function_ref<void(Value, StringRef)> setNameFn) {
2108   setNameFn(getResult(), "collapse_shape");
2109 }
2110 
2111 void ExpandShapeOp::getAsmResultNames(
2112     function_ref<void(Value, StringRef)> setNameFn) {
2113   setNameFn(getResult(), "expand_shape");
2114 }
2115 
2116 LogicalResult ExpandShapeOp::reifyResultShapes(
2117     OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2118   reifiedResultShapes = {
2119       getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2120   return success();
2121 }
2122 
2123 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2124 /// result and operand. Layout maps are verified separately.
2125 ///
2126 /// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2127 /// allowed in a reassocation group.
2128 static LogicalResult
2129 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2130                      ArrayRef<int64_t> expandedShape,
2131                      ArrayRef<ReassociationIndices> reassociation,
2132                      bool allowMultipleDynamicDimsPerGroup) {
2133   // There must be one reassociation group per collapsed dimension.
2134   if (collapsedShape.size() != reassociation.size())
2135     return op->emitOpError("invalid number of reassociation groups: found ")
2136            << reassociation.size() << ", expected " << collapsedShape.size();
2137 
2138   // The next expected expanded dimension index (while iterating over
2139   // reassociation indices).
2140   int64_t nextDim = 0;
2141   for (const auto &it : llvm::enumerate(reassociation)) {
2142     ReassociationIndices group = it.value();
2143     int64_t collapsedDim = it.index();
2144 
2145     bool foundDynamic = false;
2146     for (int64_t expandedDim : group) {
2147       if (expandedDim != nextDim++)
2148         return op->emitOpError("reassociation indices must be contiguous");
2149 
2150       if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2151         return op->emitOpError("reassociation index ")
2152                << expandedDim << " is out of bounds";
2153 
2154       // Check if there are multiple dynamic dims in a reassociation group.
2155       if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2156         if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2157           return op->emitOpError(
2158               "at most one dimension in a reassociation group may be dynamic");
2159         foundDynamic = true;
2160       }
2161     }
2162 
2163     // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2164     if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2165       return op->emitOpError("collapsed dim (")
2166              << collapsedDim
2167              << ") must be dynamic if and only if reassociation group is "
2168                 "dynamic";
2169 
2170     // If all dims in the reassociation group are static, the size of the
2171     // collapsed dim can be verified.
2172     if (!foundDynamic) {
2173       int64_t groupSize = 1;
2174       for (int64_t expandedDim : group)
2175         groupSize *= expandedShape[expandedDim];
2176       if (groupSize != collapsedShape[collapsedDim])
2177         return op->emitOpError("collapsed dim size (")
2178                << collapsedShape[collapsedDim]
2179                << ") must equal reassociation group size (" << groupSize << ")";
2180     }
2181   }
2182 
2183   if (collapsedShape.empty()) {
2184     // Rank 0: All expanded dimensions must be 1.
2185     for (int64_t d : expandedShape)
2186       if (d != 1)
2187         return op->emitOpError(
2188             "rank 0 memrefs can only be extended/collapsed with/from ones");
2189   } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2190     // Rank >= 1: Number of dimensions among all reassociation groups must match
2191     // the result memref rank.
2192     return op->emitOpError("expanded rank (")
2193            << expandedShape.size()
2194            << ") inconsistent with number of reassociation indices (" << nextDim
2195            << ")";
2196   }
2197 
2198   return success();
2199 }
2200 
2201 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2202   return getSymbolLessAffineMaps(getReassociationExprs());
2203 }
2204 
2205 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2206   return convertReassociationIndicesToExprs(getContext(),
2207                                             getReassociationIndices());
2208 }
2209 
2210 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2211   return getSymbolLessAffineMaps(getReassociationExprs());
2212 }
2213 
2214 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2215   return convertReassociationIndicesToExprs(getContext(),
2216                                             getReassociationIndices());
2217 }
2218 
2219 /// Compute the layout map after expanding a given source MemRef type with the
2220 /// specified reassociation indices.
2221 static FailureOr<StridedLayoutAttr>
2222 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2223                          ArrayRef<ReassociationIndices> reassociation) {
2224   int64_t srcOffset;
2225   SmallVector<int64_t> srcStrides;
2226   if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2227     return failure();
2228   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2229 
2230   // 1-1 mapping between srcStrides and reassociation packs.
2231   // Each srcStride starts with the given value and gets expanded according to
2232   // the proper entries in resultShape.
2233   // Example:
2234   //   srcStrides     =                   [10000,  1 ,    100   ],
2235   //   reassociations =                   [  [0], [1], [2, 3, 4]],
2236   //   resultSizes    = [2, 5, 4, 3, 2] = [  [2], [5], [4, 3, 2]]
2237   //     -> For the purpose of stride calculation, the useful sizes are:
2238   //                    [x, x, x, 3, 2] = [  [x], [x], [x, 3, 2]].
2239   //   resultStrides = [10000, 1, 600, 200, 100]
2240   // Note that a stride does not get expanded along the first entry of each
2241   // shape pack.
2242   SmallVector<int64_t> reverseResultStrides;
2243   reverseResultStrides.reserve(resultShape.size());
2244   unsigned shapeIndex = resultShape.size() - 1;
2245   for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2246     ReassociationIndices reassoc = std::get<0>(it);
2247     int64_t currentStrideToExpand = std::get<1>(it);
2248     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2249       reverseResultStrides.push_back(currentStrideToExpand);
2250       currentStrideToExpand =
2251           (SaturatedInteger::wrap(currentStrideToExpand) *
2252            SaturatedInteger::wrap(resultShape[shapeIndex--]))
2253               .asInteger();
2254     }
2255   }
2256   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2257   resultStrides.resize(resultShape.size(), 1);
2258   return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2259 }
2260 
2261 FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2262     MemRefType srcType, ArrayRef<int64_t> resultShape,
2263     ArrayRef<ReassociationIndices> reassociation) {
2264   if (srcType.getLayout().isIdentity()) {
2265     // If the source is contiguous (i.e., no layout map specified), so is the
2266     // result.
2267     MemRefLayoutAttrInterface layout;
2268     return MemRefType::get(resultShape, srcType.getElementType(), layout,
2269                            srcType.getMemorySpace());
2270   }
2271 
2272   // Source may not be contiguous. Compute the layout map.
2273   FailureOr<StridedLayoutAttr> computedLayout =
2274       computeExpandedLayoutMap(srcType, resultShape, reassociation);
2275   if (failed(computedLayout))
2276     return failure();
2277   return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2278                          srcType.getMemorySpace());
2279 }
2280 
2281 FailureOr<SmallVector<OpFoldResult>>
2282 ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2283                                 MemRefType expandedType,
2284                                 ArrayRef<ReassociationIndices> reassociation,
2285                                 ArrayRef<OpFoldResult> inputShape) {
2286   std::optional<SmallVector<OpFoldResult>> outputShape =
2287       inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2288                                   inputShape);
2289   if (!outputShape)
2290     return failure();
2291   return *outputShape;
2292 }
2293 
2294 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2295                           Type resultType, Value src,
2296                           ArrayRef<ReassociationIndices> reassociation,
2297                           ArrayRef<OpFoldResult> outputShape) {
2298   auto [staticOutputShape, dynamicOutputShape] =
2299       decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2300   build(builder, result, llvm::cast<MemRefType>(resultType), src,
2301         getReassociationIndicesAttribute(builder, reassociation),
2302         dynamicOutputShape, staticOutputShape);
2303 }
2304 
2305 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2306                           Type resultType, Value src,
2307                           ArrayRef<ReassociationIndices> reassociation) {
2308   SmallVector<OpFoldResult> inputShape =
2309       getMixedSizes(builder, result.location, src);
2310   MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2311   FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2312       builder, result.location, memrefResultTy, reassociation, inputShape);
2313   // Failure of this assertion usually indicates presence of multiple
2314   // dynamic dimensions in the same reassociation group.
2315   assert(succeeded(outputShape) && "unable to infer output shape");
2316   build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2317 }
2318 
2319 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2320                           ArrayRef<int64_t> resultShape, Value src,
2321                           ArrayRef<ReassociationIndices> reassociation) {
2322   // Only ranked memref source values are supported.
2323   auto srcType = llvm::cast<MemRefType>(src.getType());
2324   FailureOr<MemRefType> resultType =
2325       ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2326   // Failure of this assertion usually indicates a problem with the source
2327   // type, e.g., could not get strides/offset.
2328   assert(succeeded(resultType) && "could not compute layout");
2329   build(builder, result, *resultType, src, reassociation);
2330 }
2331 
2332 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2333                           ArrayRef<int64_t> resultShape, Value src,
2334                           ArrayRef<ReassociationIndices> reassociation,
2335                           ArrayRef<OpFoldResult> outputShape) {
2336   // Only ranked memref source values are supported.
2337   auto srcType = llvm::cast<MemRefType>(src.getType());
2338   FailureOr<MemRefType> resultType =
2339       ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2340   // Failure of this assertion usually indicates a problem with the source
2341   // type, e.g., could not get strides/offset.
2342   assert(succeeded(resultType) && "could not compute layout");
2343   build(builder, result, *resultType, src, reassociation, outputShape);
2344 }
2345 
2346 LogicalResult ExpandShapeOp::verify() {
2347   MemRefType srcType = getSrcType();
2348   MemRefType resultType = getResultType();
2349 
2350   if (srcType.getRank() > resultType.getRank()) {
2351     auto r0 = srcType.getRank();
2352     auto r1 = resultType.getRank();
2353     return emitOpError("has source rank ")
2354            << r0 << " and result rank " << r1 << ". This is not an expansion ("
2355            << r0 << " > " << r1 << ").";
2356   }
2357 
2358   // Verify result shape.
2359   if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2360                                   resultType.getShape(),
2361                                   getReassociationIndices(),
2362                                   /*allowMultipleDynamicDimsPerGroup=*/true)))
2363     return failure();
2364 
2365   // Compute expected result type (including layout map).
2366   FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2367       srcType, resultType.getShape(), getReassociationIndices());
2368   if (failed(expectedResultType))
2369     return emitOpError("invalid source layout map");
2370 
2371   // Check actual result type.
2372   if (*expectedResultType != resultType)
2373     return emitOpError("expected expanded type to be ")
2374            << *expectedResultType << " but found " << resultType;
2375 
2376   if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2377     return emitOpError("expected number of static shape bounds to be equal to "
2378                        "the output rank (")
2379            << resultType.getRank() << ") but found "
2380            << getStaticOutputShape().size() << " inputs instead";
2381 
2382   if ((int64_t)getOutputShape().size() !=
2383       llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2384     return emitOpError("mismatch in dynamic dims in output_shape and "
2385                        "static_output_shape: static_output_shape has ")
2386            << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2387            << " dynamic dims while output_shape has " << getOutputShape().size()
2388            << " values";
2389 
2390   // Verify if provided output shapes are in agreement with output type.
2391   DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2392   ArrayRef<int64_t> resShape = getResult().getType().getShape();
2393   for (auto [pos, shape] : llvm::enumerate(resShape)) {
2394     if (!ShapedType::isDynamic(shape) && shape != staticOutputShapes[pos]) {
2395       return emitOpError("invalid output shape provided at pos ") << pos;
2396     }
2397   }
2398 
2399   return success();
2400 }
2401 
2402 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2403                                                 MLIRContext *context) {
2404   results.add<
2405       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2406       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2407 }
2408 
2409 /// Compute the layout map after collapsing a given source MemRef type with the
2410 /// specified reassociation indices.
2411 ///
2412 /// Note: All collapsed dims in a reassociation group must be contiguous. It is
2413 /// not possible to check this by inspecting a MemRefType in the general case.
2414 /// If non-contiguity cannot be checked statically, the collapse is assumed to
2415 /// be valid (and thus accepted by this function) unless `strict = true`.
2416 static FailureOr<StridedLayoutAttr>
2417 computeCollapsedLayoutMap(MemRefType srcType,
2418                           ArrayRef<ReassociationIndices> reassociation,
2419                           bool strict = false) {
2420   int64_t srcOffset;
2421   SmallVector<int64_t> srcStrides;
2422   auto srcShape = srcType.getShape();
2423   if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2424     return failure();
2425 
2426   // The result stride of a reassociation group is the stride of the last entry
2427   // of the reassociation. (TODO: Should be the minimum stride in the
2428   // reassociation because strides are not necessarily sorted. E.g., when using
2429   // memref.transpose.) Dimensions of size 1 should be skipped, because their
2430   // strides are meaningless and could have any arbitrary value.
2431   SmallVector<int64_t> resultStrides;
2432   resultStrides.reserve(reassociation.size());
2433   for (const ReassociationIndices &reassoc : reassociation) {
2434     ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2435     while (srcShape[ref.back()] == 1 && ref.size() > 1)
2436       ref = ref.drop_back();
2437     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
2438       resultStrides.push_back(srcStrides[ref.back()]);
2439     } else {
2440       // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2441       // the corresponding stride may have to be skipped. (See above comment.)
2442       // Therefore, the result stride cannot be statically determined and must
2443       // be dynamic.
2444       resultStrides.push_back(ShapedType::kDynamic);
2445     }
2446   }
2447 
2448   // Validate that each reassociation group is contiguous.
2449   unsigned resultStrideIndex = resultStrides.size() - 1;
2450   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2451     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2452     auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2453     for (int64_t idx : llvm::reverse(trailingReassocs)) {
2454       stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2455 
2456       // Both source and result stride must have the same static value. In that
2457       // case, we can be sure, that the dimensions are collapsible (because they
2458       // are contiguous).
2459       // If `strict = false` (default during op verification), we accept cases
2460       // where one or both strides are dynamic. This is best effort: We reject
2461       // ops where obviously non-contiguous dims are collapsed, but accept ops
2462       // where we cannot be sure statically. Such ops may fail at runtime. See
2463       // the op documentation for details.
2464       auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2465       if (strict && (stride.saturated || srcStride.saturated))
2466         return failure();
2467 
2468       // Dimensions of size 1 should be skipped, because their strides are
2469       // meaningless and could have any arbitrary value.
2470       if (srcShape[idx - 1] == 1)
2471         continue;
2472 
2473       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2474         return failure();
2475     }
2476   }
2477   return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2478 }
2479 
2480 bool CollapseShapeOp::isGuaranteedCollapsible(
2481     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2482   // MemRefs with identity layout are always collapsible.
2483   if (srcType.getLayout().isIdentity())
2484     return true;
2485 
2486   return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2487                                              /*strict=*/true));
2488 }
2489 
2490 MemRefType CollapseShapeOp::computeCollapsedType(
2491     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2492   SmallVector<int64_t> resultShape;
2493   resultShape.reserve(reassociation.size());
2494   for (const ReassociationIndices &group : reassociation) {
2495     auto groupSize = SaturatedInteger::wrap(1);
2496     for (int64_t srcDim : group)
2497       groupSize =
2498           groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2499     resultShape.push_back(groupSize.asInteger());
2500   }
2501 
2502   if (srcType.getLayout().isIdentity()) {
2503     // If the source is contiguous (i.e., no layout map specified), so is the
2504     // result.
2505     MemRefLayoutAttrInterface layout;
2506     return MemRefType::get(resultShape, srcType.getElementType(), layout,
2507                            srcType.getMemorySpace());
2508   }
2509 
2510   // Source may not be fully contiguous. Compute the layout map.
2511   // Note: Dimensions that are collapsed into a single dim are assumed to be
2512   // contiguous.
2513   FailureOr<StridedLayoutAttr> computedLayout =
2514       computeCollapsedLayoutMap(srcType, reassociation);
2515   assert(succeeded(computedLayout) &&
2516          "invalid source layout map or collapsing non-contiguous dims");
2517   return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2518                          srcType.getMemorySpace());
2519 }
2520 
2521 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2522                             ArrayRef<ReassociationIndices> reassociation,
2523                             ArrayRef<NamedAttribute> attrs) {
2524   auto srcType = llvm::cast<MemRefType>(src.getType());
2525   MemRefType resultType =
2526       CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2527   result.addAttribute(::mlir::getReassociationAttrName(),
2528                       getReassociationIndicesAttribute(b, reassociation));
2529   build(b, result, resultType, src, attrs);
2530 }
2531 
2532 LogicalResult CollapseShapeOp::verify() {
2533   MemRefType srcType = getSrcType();
2534   MemRefType resultType = getResultType();
2535 
2536   if (srcType.getRank() < resultType.getRank()) {
2537     auto r0 = srcType.getRank();
2538     auto r1 = resultType.getRank();
2539     return emitOpError("has source rank ")
2540            << r0 << " and result rank " << r1 << ". This is not a collapse ("
2541            << r0 << " < " << r1 << ").";
2542   }
2543 
2544   // Verify result shape.
2545   if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2546                                   srcType.getShape(), getReassociationIndices(),
2547                                   /*allowMultipleDynamicDimsPerGroup=*/true)))
2548     return failure();
2549 
2550   // Compute expected result type (including layout map).
2551   MemRefType expectedResultType;
2552   if (srcType.getLayout().isIdentity()) {
2553     // If the source is contiguous (i.e., no layout map specified), so is the
2554     // result.
2555     MemRefLayoutAttrInterface layout;
2556     expectedResultType =
2557         MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2558                         srcType.getMemorySpace());
2559   } else {
2560     // Source may not be fully contiguous. Compute the layout map.
2561     // Note: Dimensions that are collapsed into a single dim are assumed to be
2562     // contiguous.
2563     FailureOr<StridedLayoutAttr> computedLayout =
2564         computeCollapsedLayoutMap(srcType, getReassociationIndices());
2565     if (failed(computedLayout))
2566       return emitOpError(
2567           "invalid source layout map or collapsing non-contiguous dims");
2568     expectedResultType =
2569         MemRefType::get(resultType.getShape(), srcType.getElementType(),
2570                         *computedLayout, srcType.getMemorySpace());
2571   }
2572 
2573   if (expectedResultType != resultType)
2574     return emitOpError("expected collapsed type to be ")
2575            << expectedResultType << " but found " << resultType;
2576 
2577   return success();
2578 }
2579 
2580 struct CollapseShapeOpMemRefCastFolder
2581     : public OpRewritePattern<CollapseShapeOp> {
2582 public:
2583   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2584 
2585   LogicalResult matchAndRewrite(CollapseShapeOp op,
2586                                 PatternRewriter &rewriter) const override {
2587     auto cast = op.getOperand().getDefiningOp<CastOp>();
2588     if (!cast)
2589       return failure();
2590 
2591     if (!CastOp::canFoldIntoConsumerOp(cast))
2592       return failure();
2593 
2594     Type newResultType = CollapseShapeOp::computeCollapsedType(
2595         llvm::cast<MemRefType>(cast.getOperand().getType()),
2596         op.getReassociationIndices());
2597 
2598     if (newResultType == op.getResultType()) {
2599       rewriter.modifyOpInPlace(
2600           op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2601     } else {
2602       Value newOp = rewriter.create<CollapseShapeOp>(
2603           op->getLoc(), cast.getSource(), op.getReassociationIndices());
2604       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2605     }
2606     return success();
2607   }
2608 };
2609 
2610 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2611                                                   MLIRContext *context) {
2612   results.add<
2613       ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2614       ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2615                                 memref::DimOp, MemRefType>,
2616       CollapseShapeOpMemRefCastFolder>(context);
2617 }
2618 
2619 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2620   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2621                                                        adaptor.getOperands());
2622 }
2623 
2624 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2625   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2626                                                        adaptor.getOperands());
2627 }
2628 
2629 //===----------------------------------------------------------------------===//
2630 // ReshapeOp
2631 //===----------------------------------------------------------------------===//
2632 
2633 void ReshapeOp::getAsmResultNames(
2634     function_ref<void(Value, StringRef)> setNameFn) {
2635   setNameFn(getResult(), "reshape");
2636 }
2637 
2638 LogicalResult ReshapeOp::verify() {
2639   Type operandType = getSource().getType();
2640   Type resultType = getResult().getType();
2641 
2642   Type operandElementType =
2643       llvm::cast<ShapedType>(operandType).getElementType();
2644   Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2645   if (operandElementType != resultElementType)
2646     return emitOpError("element types of source and destination memref "
2647                        "types should be the same");
2648 
2649   if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2650     if (!operandMemRefType.getLayout().isIdentity())
2651       return emitOpError("source memref type should have identity affine map");
2652 
2653   int64_t shapeSize =
2654       llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2655   auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2656   if (resultMemRefType) {
2657     if (!resultMemRefType.getLayout().isIdentity())
2658       return emitOpError("result memref type should have identity affine map");
2659     if (shapeSize == ShapedType::kDynamic)
2660       return emitOpError("cannot use shape operand with dynamic length to "
2661                          "reshape to statically-ranked memref type");
2662     if (shapeSize != resultMemRefType.getRank())
2663       return emitOpError(
2664           "length of shape operand differs from the result's memref rank");
2665   }
2666   return success();
2667 }
2668 
2669 //===----------------------------------------------------------------------===//
2670 // StoreOp
2671 //===----------------------------------------------------------------------===//
2672 
2673 LogicalResult StoreOp::verify() {
2674   if (getNumOperands() != 2 + getMemRefType().getRank())
2675     return emitOpError("store index operand count not equal to memref rank");
2676 
2677   return success();
2678 }
2679 
2680 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2681                             SmallVectorImpl<OpFoldResult> &results) {
2682   /// store(memrefcast) -> store
2683   return foldMemRefCast(*this, getValueToStore());
2684 }
2685 
2686 //===----------------------------------------------------------------------===//
2687 // SubViewOp
2688 //===----------------------------------------------------------------------===//
2689 
2690 void SubViewOp::getAsmResultNames(
2691     function_ref<void(Value, StringRef)> setNameFn) {
2692   setNameFn(getResult(), "subview");
2693 }
2694 
2695 /// A subview result type can be fully inferred from the source type and the
2696 /// static representation of offsets, sizes and strides. Special sentinels
2697 /// encode the dynamic case.
2698 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2699                                 ArrayRef<int64_t> staticOffsets,
2700                                 ArrayRef<int64_t> staticSizes,
2701                                 ArrayRef<int64_t> staticStrides) {
2702   unsigned rank = sourceMemRefType.getRank();
2703   (void)rank;
2704   assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2705   assert(staticSizes.size() == rank && "staticSizes length mismatch");
2706   assert(staticStrides.size() == rank && "staticStrides length mismatch");
2707 
2708   // Extract source offset and strides.
2709   auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2710 
2711   // Compute target offset whose value is:
2712   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2713   int64_t targetOffset = sourceOffset;
2714   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2715     auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2716     targetOffset = (SaturatedInteger::wrap(targetOffset) +
2717                     SaturatedInteger::wrap(staticOffset) *
2718                         SaturatedInteger::wrap(sourceStride))
2719                        .asInteger();
2720   }
2721 
2722   // Compute target stride whose value is:
2723   //   `sourceStrides_i * staticStrides_i`.
2724   SmallVector<int64_t, 4> targetStrides;
2725   targetStrides.reserve(staticOffsets.size());
2726   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2727     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2728     targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2729                              SaturatedInteger::wrap(staticStride))
2730                                 .asInteger());
2731   }
2732 
2733   // The type is now known.
2734   return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2735                          StridedLayoutAttr::get(sourceMemRefType.getContext(),
2736                                                 targetOffset, targetStrides),
2737                          sourceMemRefType.getMemorySpace());
2738 }
2739 
2740 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
2741                                 ArrayRef<OpFoldResult> offsets,
2742                                 ArrayRef<OpFoldResult> sizes,
2743                                 ArrayRef<OpFoldResult> strides) {
2744   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2745   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2746   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2747   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2748   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2749   if (!hasValidSizesOffsets(staticOffsets))
2750     return {};
2751   if (!hasValidSizesOffsets(staticSizes))
2752     return {};
2753   if (!hasValidStrides(staticStrides))
2754     return {};
2755   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2756                                     staticSizes, staticStrides);
2757 }
2758 
2759 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2760                                            MemRefType sourceRankedTensorType,
2761                                            ArrayRef<int64_t> offsets,
2762                                            ArrayRef<int64_t> sizes,
2763                                            ArrayRef<int64_t> strides) {
2764   auto inferredType = llvm::cast<MemRefType>(
2765       inferResultType(sourceRankedTensorType, offsets, sizes, strides));
2766   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2767          "expected ");
2768   if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2769     return inferredType;
2770 
2771   // Compute which dimensions are dropped.
2772   std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2773       computeRankReductionMask(inferredType.getShape(), resultShape);
2774   assert(dimsToProject.has_value() && "invalid rank reduction");
2775 
2776   // Compute the layout and result type.
2777   auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2778   SmallVector<int64_t> rankReducedStrides;
2779   rankReducedStrides.reserve(resultShape.size());
2780   for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2781     if (!dimsToProject->contains(idx))
2782       rankReducedStrides.push_back(value);
2783   }
2784   return MemRefType::get(resultShape, inferredType.getElementType(),
2785                          StridedLayoutAttr::get(inferredLayout.getContext(),
2786                                                 inferredLayout.getOffset(),
2787                                                 rankReducedStrides),
2788                          inferredType.getMemorySpace());
2789 }
2790 
2791 Type SubViewOp::inferRankReducedResultType(ArrayRef<int64_t> resultShape,
2792                                            MemRefType sourceRankedTensorType,
2793                                            ArrayRef<OpFoldResult> offsets,
2794                                            ArrayRef<OpFoldResult> sizes,
2795                                            ArrayRef<OpFoldResult> strides) {
2796   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2797   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2798   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2799   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2800   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2801   return SubViewOp::inferRankReducedResultType(
2802       resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2803       staticStrides);
2804 }
2805 
2806 // Build a SubViewOp with mixed static and dynamic entries and custom result
2807 // type. If the type passed is nullptr, it is inferred.
2808 void SubViewOp::build(OpBuilder &b, OperationState &result,
2809                       MemRefType resultType, Value source,
2810                       ArrayRef<OpFoldResult> offsets,
2811                       ArrayRef<OpFoldResult> sizes,
2812                       ArrayRef<OpFoldResult> strides,
2813                       ArrayRef<NamedAttribute> attrs) {
2814   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2815   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2816   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2817   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2818   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2819   auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2820   // Structuring implementation this way avoids duplication between builders.
2821   if (!resultType) {
2822     resultType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
2823         sourceMemRefType, staticOffsets, staticSizes, staticStrides));
2824   }
2825   result.addAttributes(attrs);
2826   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2827         dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2828         b.getDenseI64ArrayAttr(staticSizes),
2829         b.getDenseI64ArrayAttr(staticStrides));
2830 }
2831 
2832 // Build a SubViewOp with mixed static and dynamic entries and inferred result
2833 // type.
2834 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2835                       ArrayRef<OpFoldResult> offsets,
2836                       ArrayRef<OpFoldResult> sizes,
2837                       ArrayRef<OpFoldResult> strides,
2838                       ArrayRef<NamedAttribute> attrs) {
2839   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2840 }
2841 
2842 // Build a SubViewOp with static entries and inferred result type.
2843 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2844                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2845                       ArrayRef<int64_t> strides,
2846                       ArrayRef<NamedAttribute> attrs) {
2847   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2848       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2849         return b.getI64IntegerAttr(v);
2850       }));
2851   SmallVector<OpFoldResult> sizeValues =
2852       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2853         return b.getI64IntegerAttr(v);
2854       }));
2855   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2856       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2857         return b.getI64IntegerAttr(v);
2858       }));
2859   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2860 }
2861 
2862 // Build a SubViewOp with dynamic entries and custom result type. If the
2863 // type passed is nullptr, it is inferred.
2864 void SubViewOp::build(OpBuilder &b, OperationState &result,
2865                       MemRefType resultType, Value source,
2866                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2867                       ArrayRef<int64_t> strides,
2868                       ArrayRef<NamedAttribute> attrs) {
2869   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2870       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2871         return b.getI64IntegerAttr(v);
2872       }));
2873   SmallVector<OpFoldResult> sizeValues =
2874       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2875         return b.getI64IntegerAttr(v);
2876       }));
2877   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2878       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2879         return b.getI64IntegerAttr(v);
2880       }));
2881   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
2882         attrs);
2883 }
2884 
2885 // Build a SubViewOp with dynamic entries and custom result type. If the type
2886 // passed is nullptr, it is inferred.
2887 void SubViewOp::build(OpBuilder &b, OperationState &result,
2888                       MemRefType resultType, Value source, ValueRange offsets,
2889                       ValueRange sizes, ValueRange strides,
2890                       ArrayRef<NamedAttribute> attrs) {
2891   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2892       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
2893   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
2894       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
2895   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2896       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
2897   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2898 }
2899 
2900 // Build a SubViewOp with dynamic entries and inferred result type.
2901 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2902                       ValueRange offsets, ValueRange sizes, ValueRange strides,
2903                       ArrayRef<NamedAttribute> attrs) {
2904   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2905 }
2906 
2907 /// For ViewLikeOpInterface.
2908 Value SubViewOp::getViewSource() { return getSource(); }
2909 
2910 /// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
2911 /// static value).
2912 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
2913   int64_t t1Offset, t2Offset;
2914   SmallVector<int64_t> t1Strides, t2Strides;
2915   auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2916   auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2917   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
2918 }
2919 
2920 /// Return true if `t1` and `t2` have equal strides (both dynamic or of same
2921 /// static value). Dimensions of `t1` may be dropped in `t2`; these must be
2922 /// marked as dropped in `droppedDims`.
2923 static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
2924                                   const llvm::SmallBitVector &droppedDims) {
2925   assert(size_t(t1.getRank()) == droppedDims.size() &&
2926          "incorrect number of bits");
2927   assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
2928          "incorrect number of dropped dims");
2929   int64_t t1Offset, t2Offset;
2930   SmallVector<int64_t> t1Strides, t2Strides;
2931   auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
2932   auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
2933   if (failed(res1) || failed(res2))
2934     return false;
2935   for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
2936     if (droppedDims[i])
2937       continue;
2938     if (t1Strides[i] != t2Strides[j])
2939       return false;
2940     ++j;
2941   }
2942   return true;
2943 }
2944 
2945 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
2946                                             Operation *op, Type expectedType) {
2947   auto memrefType = llvm::cast<ShapedType>(expectedType);
2948   switch (result) {
2949   case SliceVerificationResult::Success:
2950     return success();
2951   case SliceVerificationResult::RankTooLarge:
2952     return op->emitError("expected result rank to be smaller or equal to ")
2953            << "the source rank. ";
2954   case SliceVerificationResult::SizeMismatch:
2955     return op->emitError("expected result type to be ")
2956            << expectedType
2957            << " or a rank-reduced version. (mismatch of result sizes) ";
2958   case SliceVerificationResult::ElemTypeMismatch:
2959     return op->emitError("expected result element type to be ")
2960            << memrefType.getElementType();
2961   case SliceVerificationResult::MemSpaceMismatch:
2962     return op->emitError("expected result and source memory spaces to match.");
2963   case SliceVerificationResult::LayoutMismatch:
2964     return op->emitError("expected result type to be ")
2965            << expectedType
2966            << " or a rank-reduced version. (mismatch of result layout) ";
2967   }
2968   llvm_unreachable("unexpected subview verification result");
2969 }
2970 
2971 /// Verifier for SubViewOp.
2972 LogicalResult SubViewOp::verify() {
2973   MemRefType baseType = getSourceType();
2974   MemRefType subViewType = getType();
2975 
2976   // The base memref and the view memref should be in the same memory space.
2977   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2978     return emitError("different memory spaces specified for base memref "
2979                      "type ")
2980            << baseType << " and subview memref type " << subViewType;
2981 
2982   // Verify that the base memref type has a strided layout map.
2983   if (!baseType.isStrided())
2984     return emitError("base type ") << baseType << " is not strided";
2985 
2986   // Compute the expected result type, assuming that there are no rank
2987   // reductions.
2988   auto expectedType = cast<MemRefType>(SubViewOp::inferResultType(
2989       baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()));
2990 
2991   // Verify all properties of a shaped type: rank, element type and dimension
2992   // sizes. This takes into account potential rank reductions.
2993   auto shapedTypeVerification = isRankReducedType(
2994       /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
2995   if (shapedTypeVerification != SliceVerificationResult::Success)
2996     return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
2997 
2998   // Make sure that the memory space did not change.
2999   if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3000     return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3001                                   *this, expectedType);
3002 
3003   // Verify the offset of the layout map.
3004   if (!haveCompatibleOffsets(expectedType, subViewType))
3005     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3006                                   *this, expectedType);
3007 
3008   // The only thing that's left to verify now are the strides. First, compute
3009   // the unused dimensions due to rank reductions. We have to look at sizes and
3010   // strides to decide which dimensions were dropped. This function also
3011   // partially verifies strides in case of rank reductions.
3012   auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3013                                                    getMixedSizes());
3014   if (failed(unusedDims))
3015     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3016                                   *this, expectedType);
3017 
3018   // Strides must match.
3019   if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3020     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3021                                   *this, expectedType);
3022 
3023   return success();
3024 }
3025 
3026 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3027   return os << "range " << range.offset << ":" << range.size << ":"
3028             << range.stride;
3029 }
3030 
3031 /// Return the list of Range (i.e. offset, size, stride). Each Range
3032 /// entry contains either the dynamic value or a ConstantIndexOp constructed
3033 /// with `b` at location `loc`.
3034 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3035                                               OpBuilder &b, Location loc) {
3036   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3037   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3038   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3039   SmallVector<Range, 8> res;
3040   unsigned rank = ranks[0];
3041   res.reserve(rank);
3042   for (unsigned idx = 0; idx < rank; ++idx) {
3043     Value offset =
3044         op.isDynamicOffset(idx)
3045             ? op.getDynamicOffset(idx)
3046             : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx));
3047     Value size =
3048         op.isDynamicSize(idx)
3049             ? op.getDynamicSize(idx)
3050             : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx));
3051     Value stride =
3052         op.isDynamicStride(idx)
3053             ? op.getDynamicStride(idx)
3054             : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx));
3055     res.emplace_back(Range{offset, size, stride});
3056   }
3057   return res;
3058 }
3059 
3060 /// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3061 /// to deduce the result type for the given `sourceType`. Additionally, reduce
3062 /// the rank of the inferred result type if `currentResultType` is lower rank
3063 /// than `currentSourceType`. Use this signature if `sourceType` is updated
3064 /// together with the result type. In this case, it is important to compute
3065 /// the dropped dimensions using `currentSourceType` whose strides align with
3066 /// `currentResultType`.
3067 static MemRefType getCanonicalSubViewResultType(
3068     MemRefType currentResultType, MemRefType currentSourceType,
3069     MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3070     ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3071   auto nonRankReducedType = llvm::cast<MemRefType>(SubViewOp::inferResultType(
3072       sourceType, mixedOffsets, mixedSizes, mixedStrides));
3073   FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3074       currentSourceType, currentResultType, mixedSizes);
3075   if (failed(unusedDims))
3076     return nullptr;
3077 
3078   auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3079   SmallVector<int64_t> shape, strides;
3080   unsigned numDimsAfterReduction =
3081       nonRankReducedType.getRank() - unusedDims->count();
3082   shape.reserve(numDimsAfterReduction);
3083   strides.reserve(numDimsAfterReduction);
3084   for (const auto &[idx, size, stride] :
3085        llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3086                  nonRankReducedType.getShape(), layout.getStrides())) {
3087     if (unusedDims->test(idx))
3088       continue;
3089     shape.push_back(size);
3090     strides.push_back(stride);
3091   }
3092 
3093   return MemRefType::get(shape, nonRankReducedType.getElementType(),
3094                          StridedLayoutAttr::get(sourceType.getContext(),
3095                                                 layout.getOffset(), strides),
3096                          nonRankReducedType.getMemorySpace());
3097 }
3098 
3099 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3100     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3101   auto memrefType = llvm::cast<MemRefType>(memref.getType());
3102   unsigned rank = memrefType.getRank();
3103   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3104   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3105   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3106   auto targetType =
3107       llvm::cast<MemRefType>(SubViewOp::inferRankReducedResultType(
3108           targetShape, memrefType, offsets, sizes, strides));
3109   return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3110                                            sizes, strides);
3111 }
3112 
3113 FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3114                                                Value value,
3115                                                ArrayRef<int64_t> desiredShape) {
3116   auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3117   assert(sourceMemrefType && "not a ranked memref type");
3118   auto sourceShape = sourceMemrefType.getShape();
3119   if (sourceShape.equals(desiredShape))
3120     return value;
3121   auto maybeRankReductionMask =
3122       mlir::computeRankReductionMask(sourceShape, desiredShape);
3123   if (!maybeRankReductionMask)
3124     return failure();
3125   return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3126 }
3127 
3128 /// Helper method to check if a `subview` operation is trivially a no-op. This
3129 /// is the case if the all offsets are zero, all strides are 1, and the source
3130 /// shape is same as the size of the subview. In such cases, the subview can
3131 /// be folded into its source.
3132 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3133   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3134     return false;
3135 
3136   auto mixedOffsets = subViewOp.getMixedOffsets();
3137   auto mixedSizes = subViewOp.getMixedSizes();
3138   auto mixedStrides = subViewOp.getMixedStrides();
3139 
3140   // Check offsets are zero.
3141   if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3142         std::optional<int64_t> intValue = getConstantIntValue(ofr);
3143         return !intValue || intValue.value() != 0;
3144       }))
3145     return false;
3146 
3147   // Check strides are one.
3148   if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3149         std::optional<int64_t> intValue = getConstantIntValue(ofr);
3150         return !intValue || intValue.value() != 1;
3151       }))
3152     return false;
3153 
3154   // Check all size values are static and matches the (static) source shape.
3155   ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3156   for (const auto &size : llvm::enumerate(mixedSizes)) {
3157     std::optional<int64_t> intValue = getConstantIntValue(size.value());
3158     if (!intValue || *intValue != sourceShape[size.index()])
3159       return false;
3160   }
3161   // All conditions met. The `SubViewOp` is foldable as a no-op.
3162   return true;
3163 }
3164 
3165 namespace {
3166 /// Pattern to rewrite a subview op with MemRefCast arguments.
3167 /// This essentially pushes memref.cast past its consuming subview when
3168 /// `canFoldIntoConsumerOp` is true.
3169 ///
3170 /// Example:
3171 /// ```
3172 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3173 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3174 ///     memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3175 /// ```
3176 /// is rewritten into:
3177 /// ```
3178 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3179 ///   %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3180 ///     memref<3x4xf32, strided<[?, 1], offset: ?>>
3181 /// ```
3182 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3183 public:
3184   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3185 
3186   LogicalResult matchAndRewrite(SubViewOp subViewOp,
3187                                 PatternRewriter &rewriter) const override {
3188     // Any constant operand, just return to let SubViewOpConstantFolder kick
3189     // in.
3190     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3191           return matchPattern(operand, matchConstantIndex());
3192         }))
3193       return failure();
3194 
3195     auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3196     if (!castOp)
3197       return failure();
3198 
3199     if (!CastOp::canFoldIntoConsumerOp(castOp))
3200       return failure();
3201 
3202     // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3203     // the MemRefCastOp source operand type to infer the result type and the
3204     // current SubViewOp source operand type to compute the dropped dimensions
3205     // if the operation is rank-reducing.
3206     auto resultType = getCanonicalSubViewResultType(
3207         subViewOp.getType(), subViewOp.getSourceType(),
3208         llvm::cast<MemRefType>(castOp.getSource().getType()),
3209         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3210         subViewOp.getMixedStrides());
3211     if (!resultType)
3212       return failure();
3213 
3214     Value newSubView = rewriter.create<SubViewOp>(
3215         subViewOp.getLoc(), resultType, castOp.getSource(),
3216         subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3217         subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3218         subViewOp.getStaticStrides());
3219     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3220                                         newSubView);
3221     return success();
3222   }
3223 };
3224 
3225 /// Canonicalize subview ops that are no-ops. When the source shape is not
3226 /// same as a result shape due to use of `affine_map`.
3227 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3228 public:
3229   using OpRewritePattern<SubViewOp>::OpRewritePattern;
3230 
3231   LogicalResult matchAndRewrite(SubViewOp subViewOp,
3232                                 PatternRewriter &rewriter) const override {
3233     if (!isTrivialSubViewOp(subViewOp))
3234       return failure();
3235     if (subViewOp.getSourceType() == subViewOp.getType()) {
3236       rewriter.replaceOp(subViewOp, subViewOp.getSource());
3237       return success();
3238     }
3239     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3240                                         subViewOp.getSource());
3241     return success();
3242   }
3243 };
3244 } // namespace
3245 
3246 /// Return the canonical type of the result of a subview.
3247 struct SubViewReturnTypeCanonicalizer {
3248   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3249                         ArrayRef<OpFoldResult> mixedSizes,
3250                         ArrayRef<OpFoldResult> mixedStrides) {
3251     // Infer a memref type without taking into account any rank reductions.
3252     auto resTy = SubViewOp::inferResultType(op.getSourceType(), mixedOffsets,
3253                                             mixedSizes, mixedStrides);
3254     if (!resTy)
3255       return {};
3256     MemRefType nonReducedType = cast<MemRefType>(resTy);
3257 
3258     // Directly return the non-rank reduced type if there are no dropped dims.
3259     llvm::SmallBitVector droppedDims = op.getDroppedDims();
3260     if (droppedDims.none())
3261       return nonReducedType;
3262 
3263     // Take the strides and offset from the non-rank reduced type.
3264     auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3265 
3266     // Drop dims from shape and strides.
3267     SmallVector<int64_t> targetShape;
3268     SmallVector<int64_t> targetStrides;
3269     for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3270       if (droppedDims.test(i))
3271         continue;
3272       targetStrides.push_back(nonReducedStrides[i]);
3273       targetShape.push_back(nonReducedType.getDimSize(i));
3274     }
3275 
3276     return MemRefType::get(targetShape, nonReducedType.getElementType(),
3277                            StridedLayoutAttr::get(nonReducedType.getContext(),
3278                                                   offset, targetStrides),
3279                            nonReducedType.getMemorySpace());
3280   }
3281 };
3282 
3283 /// A canonicalizer wrapper to replace SubViewOps.
3284 struct SubViewCanonicalizer {
3285   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3286     rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3287   }
3288 };
3289 
3290 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3291                                             MLIRContext *context) {
3292   results
3293       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3294                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3295            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3296 }
3297 
3298 OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3299   MemRefType sourceMemrefType = getSource().getType();
3300   MemRefType resultMemrefType = getResult().getType();
3301   auto resultLayout =
3302       dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3303 
3304   if (resultMemrefType == sourceMemrefType &&
3305       resultMemrefType.hasStaticShape() &&
3306       (!resultLayout || resultLayout.hasStaticLayout())) {
3307     return getViewSource();
3308   }
3309 
3310   // Fold subview(subview(x)), where both subviews have the same size and the
3311   // second subview's offsets are all zero. (I.e., the second subview is a
3312   // no-op.)
3313   if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3314     auto srcSizes = srcSubview.getMixedSizes();
3315     auto sizes = getMixedSizes();
3316     auto offsets = getMixedOffsets();
3317     bool allOffsetsZero = llvm::all_of(
3318         offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3319     auto strides = getMixedStrides();
3320     bool allStridesOne = llvm::all_of(
3321         strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3322     bool allSizesSame = llvm::equal(sizes, srcSizes);
3323     if (allOffsetsZero && allStridesOne && allSizesSame &&
3324         resultMemrefType == sourceMemrefType)
3325       return getViewSource();
3326   }
3327 
3328   return {};
3329 }
3330 
3331 //===----------------------------------------------------------------------===//
3332 // TransposeOp
3333 //===----------------------------------------------------------------------===//
3334 
3335 void TransposeOp::getAsmResultNames(
3336     function_ref<void(Value, StringRef)> setNameFn) {
3337   setNameFn(getResult(), "transpose");
3338 }
3339 
3340 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
3341 static MemRefType inferTransposeResultType(MemRefType memRefType,
3342                                            AffineMap permutationMap) {
3343   auto originalSizes = memRefType.getShape();
3344   auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3345   assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3346 
3347   // Compute permuted sizes and strides.
3348   auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3349   auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3350 
3351   return MemRefType::Builder(memRefType)
3352       .setShape(sizes)
3353       .setLayout(
3354           StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3355 }
3356 
3357 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3358                         AffineMapAttr permutation,
3359                         ArrayRef<NamedAttribute> attrs) {
3360   auto permutationMap = permutation.getValue();
3361   assert(permutationMap);
3362 
3363   auto memRefType = llvm::cast<MemRefType>(in.getType());
3364   // Compute result type.
3365   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3366 
3367   result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3368   build(b, result, resultType, in, attrs);
3369 }
3370 
3371 // transpose $in $permutation attr-dict : type($in) `to` type(results)
3372 void TransposeOp::print(OpAsmPrinter &p) {
3373   p << " " << getIn() << " " << getPermutation();
3374   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3375   p << " : " << getIn().getType() << " to " << getType();
3376 }
3377 
3378 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3379   OpAsmParser::UnresolvedOperand in;
3380   AffineMap permutation;
3381   MemRefType srcType, dstType;
3382   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3383       parser.parseOptionalAttrDict(result.attributes) ||
3384       parser.parseColonType(srcType) ||
3385       parser.resolveOperand(in, srcType, result.operands) ||
3386       parser.parseKeywordType("to", dstType) ||
3387       parser.addTypeToList(dstType, result.types))
3388     return failure();
3389 
3390   result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3391                       AffineMapAttr::get(permutation));
3392   return success();
3393 }
3394 
3395 LogicalResult TransposeOp::verify() {
3396   if (!getPermutation().isPermutation())
3397     return emitOpError("expected a permutation map");
3398   if (getPermutation().getNumDims() != getIn().getType().getRank())
3399     return emitOpError("expected a permutation map of same rank as the input");
3400 
3401   auto srcType = llvm::cast<MemRefType>(getIn().getType());
3402   auto resultType = llvm::cast<MemRefType>(getType());
3403   auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3404                                  .canonicalizeStridedLayout();
3405 
3406   if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3407     return emitOpError("result type ")
3408            << resultType
3409            << " is not equivalent to the canonical transposed input type "
3410            << canonicalResultType;
3411   return success();
3412 }
3413 
3414 OpFoldResult TransposeOp::fold(FoldAdaptor) {
3415   // First check for identity permutation, we can fold it away if input and
3416   // result types are identical already.
3417   if (getPermutation().isIdentity() && getType() == getIn().getType())
3418     return getIn();
3419   // Fold two consecutive memref.transpose Ops into one by composing their
3420   // permutation maps.
3421   if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3422     AffineMap composedPermutation =
3423         getPermutation().compose(otherTransposeOp.getPermutation());
3424     getInMutable().assign(otherTransposeOp.getIn());
3425     setPermutation(composedPermutation);
3426     return getResult();
3427   }
3428   return {};
3429 }
3430 
3431 //===----------------------------------------------------------------------===//
3432 // ViewOp
3433 //===----------------------------------------------------------------------===//
3434 
3435 void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3436   setNameFn(getResult(), "view");
3437 }
3438 
3439 LogicalResult ViewOp::verify() {
3440   auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3441   auto viewType = getType();
3442 
3443   // The base memref should have identity layout map (or none).
3444   if (!baseType.getLayout().isIdentity())
3445     return emitError("unsupported map for base memref type ") << baseType;
3446 
3447   // The result memref should have identity layout map (or none).
3448   if (!viewType.getLayout().isIdentity())
3449     return emitError("unsupported map for result memref type ") << viewType;
3450 
3451   // The base memref and the view memref should be in the same memory space.
3452   if (baseType.getMemorySpace() != viewType.getMemorySpace())
3453     return emitError("different memory spaces specified for base memref "
3454                      "type ")
3455            << baseType << " and view memref type " << viewType;
3456 
3457   // Verify that we have the correct number of sizes for the result type.
3458   unsigned numDynamicDims = viewType.getNumDynamicDims();
3459   if (getSizes().size() != numDynamicDims)
3460     return emitError("incorrect number of size operands for type ") << viewType;
3461 
3462   return success();
3463 }
3464 
3465 Value ViewOp::getViewSource() { return getSource(); }
3466 
3467 namespace {
3468 
3469 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3470   using OpRewritePattern<ViewOp>::OpRewritePattern;
3471 
3472   LogicalResult matchAndRewrite(ViewOp viewOp,
3473                                 PatternRewriter &rewriter) const override {
3474     // Return if none of the operands are constants.
3475     if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3476           return matchPattern(operand, matchConstantIndex());
3477         }))
3478       return failure();
3479 
3480     // Get result memref type.
3481     auto memrefType = viewOp.getType();
3482 
3483     // Get offset from old memref view type 'memRefType'.
3484     int64_t oldOffset;
3485     SmallVector<int64_t, 4> oldStrides;
3486     if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3487       return failure();
3488     assert(oldOffset == 0 && "Expected 0 offset");
3489 
3490     SmallVector<Value, 4> newOperands;
3491 
3492     // Offset cannot be folded into result type.
3493 
3494     // Fold any dynamic dim operands which are produced by a constant.
3495     SmallVector<int64_t, 4> newShapeConstants;
3496     newShapeConstants.reserve(memrefType.getRank());
3497 
3498     unsigned dynamicDimPos = 0;
3499     unsigned rank = memrefType.getRank();
3500     for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3501       int64_t dimSize = memrefType.getDimSize(dim);
3502       // If this is already static dimension, keep it.
3503       if (!ShapedType::isDynamic(dimSize)) {
3504         newShapeConstants.push_back(dimSize);
3505         continue;
3506       }
3507       auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3508       if (auto constantIndexOp =
3509               dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3510         // Dynamic shape dimension will be folded.
3511         newShapeConstants.push_back(constantIndexOp.value());
3512       } else {
3513         // Dynamic shape dimension not folded; copy operand from old memref.
3514         newShapeConstants.push_back(dimSize);
3515         newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3516       }
3517       dynamicDimPos++;
3518     }
3519 
3520     // Create new memref type with constant folded dims.
3521     MemRefType newMemRefType =
3522         MemRefType::Builder(memrefType).setShape(newShapeConstants);
3523     // Nothing new, don't fold.
3524     if (newMemRefType == memrefType)
3525       return failure();
3526 
3527     // Create new ViewOp.
3528     auto newViewOp = rewriter.create<ViewOp>(
3529         viewOp.getLoc(), newMemRefType, viewOp.getOperand(0),
3530         viewOp.getByteShift(), newOperands);
3531     // Insert a cast so we have the same type as the old memref type.
3532     rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3533     return success();
3534   }
3535 };
3536 
3537 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3538   using OpRewritePattern<ViewOp>::OpRewritePattern;
3539 
3540   LogicalResult matchAndRewrite(ViewOp viewOp,
3541                                 PatternRewriter &rewriter) const override {
3542     Value memrefOperand = viewOp.getOperand(0);
3543     CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3544     if (!memrefCastOp)
3545       return failure();
3546     Value allocOperand = memrefCastOp.getOperand();
3547     AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3548     if (!allocOp)
3549       return failure();
3550     rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3551                                         viewOp.getByteShift(),
3552                                         viewOp.getSizes());
3553     return success();
3554   }
3555 };
3556 
3557 } // namespace
3558 
3559 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3560                                          MLIRContext *context) {
3561   results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3562 }
3563 
3564 //===----------------------------------------------------------------------===//
3565 // AtomicRMWOp
3566 //===----------------------------------------------------------------------===//
3567 
3568 LogicalResult AtomicRMWOp::verify() {
3569   if (getMemRefType().getRank() != getNumOperands() - 2)
3570     return emitOpError(
3571         "expects the number of subscripts to be equal to memref rank");
3572   switch (getKind()) {
3573   case arith::AtomicRMWKind::addf:
3574   case arith::AtomicRMWKind::maximumf:
3575   case arith::AtomicRMWKind::minimumf:
3576   case arith::AtomicRMWKind::mulf:
3577     if (!llvm::isa<FloatType>(getValue().getType()))
3578       return emitOpError() << "with kind '"
3579                            << arith::stringifyAtomicRMWKind(getKind())
3580                            << "' expects a floating-point type";
3581     break;
3582   case arith::AtomicRMWKind::addi:
3583   case arith::AtomicRMWKind::maxs:
3584   case arith::AtomicRMWKind::maxu:
3585   case arith::AtomicRMWKind::mins:
3586   case arith::AtomicRMWKind::minu:
3587   case arith::AtomicRMWKind::muli:
3588   case arith::AtomicRMWKind::ori:
3589   case arith::AtomicRMWKind::andi:
3590     if (!llvm::isa<IntegerType>(getValue().getType()))
3591       return emitOpError() << "with kind '"
3592                            << arith::stringifyAtomicRMWKind(getKind())
3593                            << "' expects an integer type";
3594     break;
3595   default:
3596     break;
3597   }
3598   return success();
3599 }
3600 
3601 OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3602   /// atomicrmw(memrefcast) -> atomicrmw
3603   if (succeeded(foldMemRefCast(*this, getValue())))
3604     return getResult();
3605   return OpFoldResult();
3606 }
3607 
3608 //===----------------------------------------------------------------------===//
3609 // TableGen'd op method definitions
3610 //===----------------------------------------------------------------------===//
3611 
3612 #define GET_OP_CLASSES
3613 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
3614