xref: /llvm-project/mlir/lib/Dialect/Affine/IR/AffineOps.cpp (revision d28a4f1fc02dc34a87fa22af0a053e8f1e7f6cea)
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/UB/IR/UBOps.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/AffineExprVisitor.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Interfaces/ShapedOpInterfaces.h"
21 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <numeric>
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::affine;
35 
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
38 using llvm::mod;
39 
40 #define DEBUG_TYPE "affine-ops"
41 
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43 
44 /// A utility function to check if a value is defined at the top level of
45 /// `region` or is an argument of `region`. A value of index type defined at the
46 /// top level of a `AffineScope` region is always a valid symbol for all
47 /// uses in that region.
48 bool mlir::affine::isTopLevelValue(Value value, Region *region) {
49   if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50     return arg.getParentRegion() == region;
51   return value.getDefiningOp()->getParentRegion() == region;
52 }
53 
54 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
55 /// region remains legal if the operation that uses it is inlined into `dest`
56 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
57 /// `isValidSymbol`, depending on the value being required to remain a valid
58 /// dimension or symbol.
59 static bool
60 remainsLegalAfterInline(Value value, Region *src, Region *dest,
61                         const IRMapping &mapping,
62                         function_ref<bool(Value, Region *)> legalityCheck) {
63   // If the value is a valid dimension for any other reason than being
64   // a top-level value, it will remain valid: constants get inlined
65   // with the function, transitive affine applies also get inlined and
66   // will be checked themselves, etc.
67   if (!isTopLevelValue(value, src))
68     return true;
69 
70   // If it's a top-level value because it's a block operand, i.e. a
71   // function argument, check whether the value replacing it after
72   // inlining is a valid dimension in the new region.
73   if (llvm::isa<BlockArgument>(value))
74     return legalityCheck(mapping.lookup(value), dest);
75 
76   // If it's a top-level value because it's defined in the region,
77   // it can only be inlined if the defining op is a constant or a
78   // `dim`, which can appear anywhere and be valid, since the defining
79   // op won't be top-level anymore after inlining.
80   Attribute operandCst;
81   bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
82   return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
83          isDimLikeOp;
84 }
85 
86 /// Checks if all values known to be legal affine dimensions or symbols in `src`
87 /// remain so if their respective users are inlined into `dest`.
88 static bool
89 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
90                         const IRMapping &mapping,
91                         function_ref<bool(Value, Region *)> legalityCheck) {
92   return llvm::all_of(values, [&](Value v) {
93     return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
94   });
95 }
96 
97 /// Checks if an affine read or write operation remains legal after inlining
98 /// from `src` to `dest`.
99 template <typename OpTy>
100 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
101                                     const IRMapping &mapping) {
102   static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103                                 AffineWriteOpInterface>::value,
104                 "only ops with affine read/write interface are supported");
105 
106   AffineMap map = op.getAffineMap();
107   ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
108   ValueRange symbolOperands =
109       op.getMapOperands().take_back(map.getNumSymbols());
110   if (!remainsLegalAfterInline(
111           dimOperands, src, dest, mapping,
112           static_cast<bool (*)(Value, Region *)>(isValidDim)))
113     return false;
114   if (!remainsLegalAfterInline(
115           symbolOperands, src, dest, mapping,
116           static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
117     return false;
118   return true;
119 }
120 
121 /// Checks if an affine apply operation remains legal after inlining from `src`
122 /// to `dest`.
123 //  Use "unused attribute" marker to silence clang-tidy warning stemming from
124 //  the inability to see through "llvm::TypeSwitch".
125 template <>
126 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
127                                                    Region *src, Region *dest,
128                                                    const IRMapping &mapping) {
129   // If it's a valid dimension, we need to check that it remains so.
130   if (isValidDim(op.getResult(), src))
131     return remainsLegalAfterInline(
132         op.getMapOperands(), src, dest, mapping,
133         static_cast<bool (*)(Value, Region *)>(isValidDim));
134 
135   // Otherwise it must be a valid symbol, check that it remains so.
136   return remainsLegalAfterInline(
137       op.getMapOperands(), src, dest, mapping,
138       static_cast<bool (*)(Value, Region *)>(isValidSymbol));
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // AffineDialect Interfaces
143 //===----------------------------------------------------------------------===//
144 
145 namespace {
146 /// This class defines the interface for handling inlining with affine
147 /// operations.
148 struct AffineInlinerInterface : public DialectInlinerInterface {
149   using DialectInlinerInterface::DialectInlinerInterface;
150 
151   //===--------------------------------------------------------------------===//
152   // Analysis Hooks
153   //===--------------------------------------------------------------------===//
154 
155   /// Returns true if the given region 'src' can be inlined into the region
156   /// 'dest' that is attached to an operation registered to the current dialect.
157   /// 'wouldBeCloned' is set if the region is cloned into its new location
158   /// rather than moved, indicating there may be other users.
159   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
160                        IRMapping &valueMapping) const final {
161     // We can inline into affine loops and conditionals if this doesn't break
162     // affine value categorization rules.
163     Operation *destOp = dest->getParentOp();
164     if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165       return false;
166 
167     // Multi-block regions cannot be inlined into affine constructs, all of
168     // which require single-block regions.
169     if (!llvm::hasSingleElement(*src))
170       return false;
171 
172     // Side-effecting operations that the affine dialect cannot understand
173     // should not be inlined.
174     Block &srcBlock = src->front();
175     for (Operation &op : srcBlock) {
176       // Ops with no side effects are fine,
177       if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178         if (iface.hasNoEffect())
179           continue;
180       }
181 
182       // Assuming the inlined region is valid, we only need to check if the
183       // inlining would change it.
184       bool remainsValid =
185           llvm::TypeSwitch<Operation *, bool>(&op)
186               .Case<AffineApplyOp, AffineReadOpInterface,
187                     AffineWriteOpInterface>([&](auto op) {
188                 return remainsLegalAfterInline(op, src, dest, valueMapping);
189               })
190               .Default([](Operation *) {
191                 // Conservatively disallow inlining ops we cannot reason about.
192                 return false;
193               });
194 
195       if (!remainsValid)
196         return false;
197     }
198 
199     return true;
200   }
201 
202   /// Returns true if the given operation 'op', that is registered to this
203   /// dialect, can be inlined into the given region, false otherwise.
204   bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
205                        IRMapping &valueMapping) const final {
206     // Always allow inlining affine operations into a region that is marked as
207     // affine scope, or into affine loops and conditionals. There are some edge
208     // cases when inlining *into* affine structures, but that is handled in the
209     // other 'isLegalToInline' hook above.
210     Operation *parentOp = region->getParentOp();
211     return parentOp->hasTrait<OpTrait::AffineScope>() ||
212            isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213   }
214 
215   /// Affine regions should be analyzed recursively.
216   bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217 };
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // AffineDialect
222 //===----------------------------------------------------------------------===//
223 
/// Registers the affine dialect's operations and its inliner interface, and
/// declares the ops for which a `ValueBoundsOpInterface` implementation is
/// promised.
void AffineDialect::initialize() {
  // AffineDmaStartOp and AffineDmaWaitOp are listed explicitly; the remaining
  // ops are spliced in from the ODS-generated GET_OP_LIST include.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
  // NOTE(review): the promised ValueBoundsOpInterface implementations are
  // presumably registered elsewhere as external models — confirm.
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
233 
234 /// Materialize a single constant operation from a given attribute value with
235 /// the desired resultant type.
236 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
237                                               Attribute value, Type type,
238                                               Location loc) {
239   if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240     return builder.create<ub::PoisonOp>(loc, type, poison);
241   return arith::ConstantOp::materialize(builder, value, type, loc);
242 }
243 
244 /// A utility function to check if a value is defined at the top level of an
245 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
246 /// conservatively assume it is not top-level. A value of index type defined at
247 /// the top level is always a valid symbol.
248 bool mlir::affine::isTopLevelValue(Value value) {
249   if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
250     // The block owning the argument may be unlinked, e.g. when the surrounding
251     // region has not yet been attached to an Op, at which point the parent Op
252     // is null.
253     Operation *parentOp = arg.getOwner()->getParentOp();
254     return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
255   }
256   // The defining Op may live in an unlinked block so its parent Op may be null.
257   Operation *parentOp = value.getDefiningOp()->getParentOp();
258   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
259 }
260 
261 /// Returns the closest region enclosing `op` that is held by an operation with
262 /// trait `AffineScope`; `nullptr` if there is no such region.
263 Region *mlir::affine::getAffineScope(Operation *op) {
264   auto *curOp = op;
265   while (auto *parentOp = curOp->getParentOp()) {
266     if (parentOp->hasTrait<OpTrait::AffineScope>())
267       return curOp->getParentRegion();
268     curOp = parentOp;
269   }
270   return nullptr;
271 }
272 
273 // A Value can be used as a dimension id iff it meets one of the following
274 // conditions:
275 // *) It is valid as a symbol.
276 // *) It is an induction variable.
277 // *) It is the result of affine apply operation with dimension id arguments.
278 bool mlir::affine::isValidDim(Value value) {
279   // The value must be an index type.
280   if (!value.getType().isIndex())
281     return false;
282 
283   if (auto *defOp = value.getDefiningOp())
284     return isValidDim(value, getAffineScope(defOp));
285 
286   // This value has to be a block argument for an op that has the
287   // `AffineScope` trait or for an affine.for or affine.parallel.
288   auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
289   return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
290                       isa<AffineForOp, AffineParallelOp>(parentOp));
291 }
292 
293 // Value can be used as a dimension id iff it meets one of the following
294 // conditions:
295 // *) It is valid as a symbol.
296 // *) It is an induction variable.
297 // *) It is the result of an affine apply operation with dimension id operands.
298 bool mlir::affine::isValidDim(Value value, Region *region) {
299   // The value must be an index type.
300   if (!value.getType().isIndex())
301     return false;
302 
303   // All valid symbols are okay.
304   if (isValidSymbol(value, region))
305     return true;
306 
307   auto *op = value.getDefiningOp();
308   if (!op) {
309     // This value has to be a block argument for an affine.for or an
310     // affine.parallel.
311     auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
312     return isa<AffineForOp, AffineParallelOp>(parentOp);
313   }
314 
315   // Affine apply operation is ok if all of its operands are ok.
316   if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317     return applyOp.isValidDim(region);
318   // The dim op is okay if its operand memref/tensor is defined at the top
319   // level.
320   if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
321     return isTopLevelValue(dimOp.getShapedValue());
322   return false;
323 }
324 
325 /// Returns true if the 'index' dimension of the `memref` defined by
326 /// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
327 /// for `region`.
328 template <typename AnyMemRefDefOp>
329 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
330                                     Region *region) {
331   MemRefType memRefType = memrefDefOp.getType();
332 
333   // Dimension index is out of bounds.
334   if (index >= memRefType.getRank()) {
335     return false;
336   }
337 
338   // Statically shaped.
339   if (!memRefType.isDynamicDim(index))
340     return true;
341   // Get the position of the dimension among dynamic dimensions;
342   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
343   return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
344                        region);
345 }
346 
347 /// Returns true if the result of the dim op is a valid symbol for `region`.
348 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
349   // The dim op is okay if its source is defined at the top level.
350   if (isTopLevelValue(dimOp.getShapedValue()))
351     return true;
352 
353   // Conservatively handle remaining BlockArguments as non-valid symbols.
354   // E.g. scf.for iterArgs.
355   if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
356     return false;
357 
358   // The dim op is also okay if its operand memref is a view/subview whose
359   // corresponding size is a valid symbol.
360   std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
361 
362   // Be conservative if we can't understand the dimension.
363   if (!index.has_value())
364     return false;
365 
366   // Skip over all memref.cast ops (if any).
367   Operation *op = dimOp.getShapedValue().getDefiningOp();
368   while (auto castOp = dyn_cast<memref::CastOp>(op)) {
369     // Bail on unranked memrefs.
370     if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
371       return false;
372     op = castOp.getSource().getDefiningOp();
373     if (!op)
374       return false;
375   }
376 
377   int64_t i = index.value();
378   return TypeSwitch<Operation *, bool>(op)
379       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
380           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
381       .Default([](Operation *) { return false; });
382 }
383 
384 // A value can be used as a symbol (at all its use sites) iff it meets one of
385 // the following conditions:
386 // *) It is a constant.
387 // *) Its defining op or block arg appearance is immediately enclosed by an op
388 //    with `AffineScope` trait.
389 // *) It is the result of an affine.apply operation with symbol operands.
390 // *) It is a result of the dim op on a memref whose corresponding size is a
391 //    valid symbol.
392 bool mlir::affine::isValidSymbol(Value value) {
393   if (!value)
394     return false;
395 
396   // The value must be an index type.
397   if (!value.getType().isIndex())
398     return false;
399 
400   // Check that the value is a top level value.
401   if (isTopLevelValue(value))
402     return true;
403 
404   if (auto *defOp = value.getDefiningOp())
405     return isValidSymbol(value, getAffineScope(defOp));
406 
407   return false;
408 }
409 
410 /// A value can be used as a symbol for `region` iff it meets one of the
411 /// following conditions:
412 /// *) It is a constant.
413 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
414 /// *) identifiers.
415 /// *) It is a result of the dim op on a memref whose corresponding size is
416 ///    a valid symbol.
417 /// *) It is defined at the top level of 'region' or is its argument.
418 /// *) It dominates `region`'s parent op.
419 /// If `region` is null, conservatively assume the symbol definition scope does
420 /// not exist and only accept the values that would be symbols regardless of
421 /// the surrounding region structure, i.e. the first three cases above.
422 bool mlir::affine::isValidSymbol(Value value, Region *region) {
423   // The value must be an index type.
424   if (!value.getType().isIndex())
425     return false;
426 
427   // A top-level value is a valid symbol.
428   if (region && ::isTopLevelValue(value, region))
429     return true;
430 
431   auto *defOp = value.getDefiningOp();
432   if (!defOp) {
433     // A block argument that is not a top-level value is a valid symbol if it
434     // dominates region's parent op.
435     Operation *regionOp = region ? region->getParentOp() : nullptr;
436     if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
437       if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
438         return isValidSymbol(value, parentOpRegion);
439     return false;
440   }
441 
442   // Constant operation is ok.
443   Attribute operandCst;
444   if (matchPattern(defOp, m_Constant(&operandCst)))
445     return true;
446 
447   // `Pure` operation that whose operands are valid symbolic identifiers.
448   if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
449         return affine::isValidSymbol(operand, region);
450       })) {
451     return true;
452   }
453 
454   // Dim op results could be valid symbols at any level.
455   if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
456     return isDimOpValidSymbol(dimOp, region);
457 
458   // Check for values dominating `region`'s parent op.
459   Operation *regionOp = region ? region->getParentOp() : nullptr;
460   if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
461     if (auto *parentRegion = region->getParentOp()->getParentRegion())
462       return isValidSymbol(value, parentRegion);
463 
464   return false;
465 }
466 
467 // Returns true if 'value' is a valid index to an affine operation (e.g.
468 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
469 // `region` provides the polyhedral symbol scope. Returns false otherwise.
470 static bool isValidAffineIndexOperand(Value value, Region *region) {
471   return isValidDim(value, region) || isValidSymbol(value, region);
472 }
473 
474 /// Prints dimension and symbol list.
475 static void printDimAndSymbolList(Operation::operand_iterator begin,
476                                   Operation::operand_iterator end,
477                                   unsigned numDims, OpAsmPrinter &printer) {
478   OperandRange operands(begin, end);
479   printer << '(' << operands.take_front(numDims) << ')';
480   if (operands.size() > numDims)
481     printer << '[' << operands.drop_front(numDims) << ']';
482 }
483 
484 /// Parses dimension and symbol list and returns true if parsing failed.
485 ParseResult mlir::affine::parseDimAndSymbolList(
486     OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
487   SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
488   if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
489     return failure();
490   // Store number of dimensions for validation by caller.
491   numDims = opInfos.size();
492 
493   // Parse the optional symbol operands.
494   auto indexTy = parser.getBuilder().getIndexType();
495   return failure(parser.parseOperandList(
496                      opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
497                  parser.resolveOperands(opInfos, indexTy, operands));
498 }
499 
500 /// Utility function to verify that a set of operands are valid dimension and
501 /// symbol identifiers. The operands should be laid out such that the dimension
502 /// operands are before the symbol operands. This function returns failure if
503 /// there was an invalid operand. An operation is provided to emit any necessary
504 /// errors.
505 template <typename OpTy>
506 static LogicalResult
507 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
508                               unsigned numDims) {
509   unsigned opIt = 0;
510   for (auto operand : operands) {
511     if (opIt++ < numDims) {
512       if (!isValidDim(operand, getAffineScope(op)))
513         return op.emitOpError("operand cannot be used as a dimension id");
514     } else if (!isValidSymbol(operand, getAffineScope(op))) {
515       return op.emitOpError("operand cannot be used as a symbol");
516     }
517   }
518   return success();
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // AffineApplyOp
523 //===----------------------------------------------------------------------===//
524 
525 AffineValueMap AffineApplyOp::getAffineValueMap() {
526   return AffineValueMap(getAffineMap(), getOperands(), getResult());
527 }
528 
529 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
530   auto &builder = parser.getBuilder();
531   auto indexTy = builder.getIndexType();
532 
533   AffineMapAttr mapAttr;
534   unsigned numDims;
535   if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
536       parseDimAndSymbolList(parser, result.operands, numDims) ||
537       parser.parseOptionalAttrDict(result.attributes))
538     return failure();
539   auto map = mapAttr.getValue();
540 
541   if (map.getNumDims() != numDims ||
542       numDims + map.getNumSymbols() != result.operands.size()) {
543     return parser.emitError(parser.getNameLoc(),
544                             "dimension or symbol index mismatch");
545   }
546 
547   result.types.append(map.getNumResults(), indexTy);
548   return success();
549 }
550 
551 void AffineApplyOp::print(OpAsmPrinter &p) {
552   p << " " << getMapAttr();
553   printDimAndSymbolList(operand_begin(), operand_end(),
554                         getAffineMap().getNumDims(), p);
555   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
556 }
557 
558 LogicalResult AffineApplyOp::verify() {
559   // Check input and output dimensions match.
560   AffineMap affineMap = getMap();
561 
562   // Verify that operand count matches affine map dimension and symbol count.
563   if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
564     return emitOpError(
565         "operand count and affine map dimension and symbol count must match");
566 
567   // Verify that the map only produces one result.
568   if (affineMap.getNumResults() != 1)
569     return emitOpError("mapping must produce one value");
570 
571   return success();
572 }
573 
574 // The result of the affine apply operation can be used as a dimension id if all
575 // its operands are valid dimension ids.
576 bool AffineApplyOp::isValidDim() {
577   return llvm::all_of(getOperands(),
578                       [](Value op) { return affine::isValidDim(op); });
579 }
580 
581 // The result of the affine apply operation can be used as a dimension id if all
582 // its operands are valid dimension ids with the parent operation of `region`
583 // defining the polyhedral scope for symbols.
584 bool AffineApplyOp::isValidDim(Region *region) {
585   return llvm::all_of(getOperands(),
586                       [&](Value op) { return ::isValidDim(op, region); });
587 }
588 
589 // The result of the affine apply operation can be used as a symbol if all its
590 // operands are symbols.
591 bool AffineApplyOp::isValidSymbol() {
592   return llvm::all_of(getOperands(),
593                       [](Value op) { return affine::isValidSymbol(op); });
594 }
595 
596 // The result of the affine apply operation can be used as a symbol in `region`
597 // if all its operands are symbols in `region`.
598 bool AffineApplyOp::isValidSymbol(Region *region) {
599   return llvm::all_of(getOperands(), [&](Value operand) {
600     return affine::isValidSymbol(operand, region);
601   });
602 }
603 
604 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
605   auto map = getAffineMap();
606 
607   // Fold dims and symbols to existing values.
608   auto expr = map.getResult(0);
609   if (auto dim = dyn_cast<AffineDimExpr>(expr))
610     return getOperand(dim.getPosition());
611   if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
612     return getOperand(map.getNumDims() + sym.getPosition());
613 
614   // Otherwise, default to folding the map.
615   SmallVector<Attribute, 1> result;
616   bool hasPoison = false;
617   auto foldResult =
618       map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
619   if (hasPoison)
620     return ub::PoisonAttr::get(getContext());
621   if (failed(foldResult))
622     return {};
623   return result[0];
624 }
625 
626 /// Returns the largest known divisor of `e`. Exploits information from the
627 /// values in `operands`.
628 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
629   // This method isn't aware of `operands`.
630   int64_t div = e.getLargestKnownDivisor();
631 
632   // We now make use of operands for the case `e` is a dim expression.
633   // TODO: More powerful simplification would have to modify
634   // getLargestKnownDivisor to take `operands` and exploit that information as
635   // well for dim/sym expressions, but in that case, getLargestKnownDivisor
636   // can't be part of the IR library but of the `Analysis` library. The IR
637   // library can only really depend on simple O(1) checks.
638   auto dimExpr = dyn_cast<AffineDimExpr>(e);
639   // If it's not a dim expr, `div` is the best we have.
640   if (!dimExpr)
641     return div;
642 
643   // We simply exploit information from loop IVs.
644   // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
645   // desired simplifications are expected to be part of other
646   // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
647   // LoopAnalysis library.
648   Value operand = operands[dimExpr.getPosition()];
649   int64_t operandDivisor = 1;
650   // TODO: With the right accessors, this can be extended to
651   // LoopLikeOpInterface.
652   if (AffineForOp forOp = getForInductionVarOwner(operand)) {
653     if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
654       operandDivisor = forOp.getStepAsInt();
655     } else {
656       uint64_t lbLargestKnownDivisor =
657           forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
658       operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
659     }
660   }
661   return operandDivisor;
662 }
663 
664 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
665 /// being an affine dim expression or a constant.
666 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
667                                    int64_t k) {
668   if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
669     int64_t constVal = constExpr.getValue();
670     return constVal >= 0 && constVal < k;
671   }
672   auto dimExpr = dyn_cast<AffineDimExpr>(e);
673   if (!dimExpr)
674     return false;
675   Value operand = operands[dimExpr.getPosition()];
676   // TODO: With the right accessors, this can be extended to
677   // LoopLikeOpInterface.
678   if (AffineForOp forOp = getForInductionVarOwner(operand)) {
679     if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
680         forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
681       return true;
682     }
683   }
684 
685   // We don't consider other cases like `operand` being defined by a constant or
686   // an affine.apply op since such cases will already be handled by other
687   // patterns and propagation of loop IVs or constant would happen.
688   return false;
689 }
690 
691 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
692 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
693 /// expression is in that form.
694 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
695                            AffineExpr &quotientTimesDiv, AffineExpr &rem) {
696   auto bin = dyn_cast<AffineBinaryOpExpr>(e);
697   if (!bin || bin.getKind() != AffineExprKind::Add)
698     return false;
699 
700   AffineExpr llhs = bin.getLHS();
701   AffineExpr rlhs = bin.getRHS();
702   div = getLargestKnownDivisor(llhs, operands);
703   if (isNonNegativeBoundedBy(rlhs, operands, div)) {
704     quotientTimesDiv = llhs;
705     rem = rlhs;
706     return true;
707   }
708   div = getLargestKnownDivisor(rlhs, operands);
709   if (isNonNegativeBoundedBy(llhs, operands, div)) {
710     quotientTimesDiv = rlhs;
711     rem = llhs;
712     return true;
713   }
714   return false;
715 }
716 
717 /// Gets the constant lower bound on an `iv`.
718 static std::optional<int64_t> getLowerBound(Value iv) {
719   AffineForOp forOp = getForInductionVarOwner(iv);
720   if (forOp && forOp.hasConstantLowerBound())
721     return forOp.getConstantLowerBound();
722   return std::nullopt;
723 }
724 
725 /// Gets the constant upper bound on an affine.for `iv`.
726 static std::optional<int64_t> getUpperBound(Value iv) {
727   AffineForOp forOp = getForInductionVarOwner(iv);
728   if (!forOp || !forOp.hasConstantUpperBound())
729     return std::nullopt;
730 
731   // If its lower bound is also known, we can get a more precise bound
732   // whenever the step is not one.
733   if (forOp.hasConstantLowerBound()) {
734     return forOp.getConstantUpperBound() - 1 -
735            (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
736                forOp.getStepAsInt();
737   }
738   return forOp.getConstantUpperBound() - 1;
739 }
740 
741 /// Determine a constant upper bound for `expr` if one exists while exploiting
742 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
743 /// is guaranteed to be less than or equal to it.
744 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
745                                             unsigned numSymbols,
746                                             ArrayRef<Value> operands) {
747   // Get the constant lower or upper bounds on the operands.
748   SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
749   constLowerBounds.reserve(operands.size());
750   constUpperBounds.reserve(operands.size());
751   for (Value operand : operands) {
752     constLowerBounds.push_back(getLowerBound(operand));
753     constUpperBounds.push_back(getUpperBound(operand));
754   }
755 
756   if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
757     return constExpr.getValue();
758 
759   return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
760                                constUpperBounds,
761                                /*isUpper=*/true);
762 }
763 
764 /// Determine a constant lower bound for `expr` if one exists while exploiting
765 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
766 /// is guaranteed to be less than or equal to it.
767 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
768                                             unsigned numSymbols,
769                                             ArrayRef<Value> operands) {
770   // Get the constant lower or upper bounds on the operands.
771   SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
772   constLowerBounds.reserve(operands.size());
773   constUpperBounds.reserve(operands.size());
774   for (Value operand : operands) {
775     constLowerBounds.push_back(getLowerBound(operand));
776     constUpperBounds.push_back(getUpperBound(operand));
777   }
778 
779   std::optional<int64_t> lowerBound;
780   if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
781     lowerBound = constExpr.getValue();
782   } else {
783     lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
784                                        constLowerBounds, constUpperBounds,
785                                        /*isUpper=*/false);
786   }
787   return lowerBound;
788 }
789 
/// Simplify `expr` while exploiting information from the values in `operands`.
///
/// `expr` is rewritten in place, children first (bottom-up). `operands` is only
/// read: it is handed to the bound and divisor helpers, and it is never
/// modified despite the function's name. `numDims`/`numSymbols` describe the
/// dims and symbols `expr` may refer to. Only floordiv, ceildiv and mod nodes
/// with a positive constant divisor are rewritten; every other binary node is
/// merely rebuilt from its simplified children.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/ceildiv/mod expressions; non-binary
  // expressions (dims, symbols, constants) are returned unchanged.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // Rebuilding may have produced a different kind of expression (e.g. a
  // constant); only floordiv/ceildiv/mod proceed to the rules below.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions (division or modulo by a non-positive constant)
  // aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv, ceildiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if the entire range of lhs has the same
    // floor quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // The remaining rules only apply to floordiv and mod; a ceildiv that was not
  // folded above is left as is.
  //
  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  // Once lhs matches this form, no further rule is tried, even when neither
  // rewrite applies.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs is a known multiple of c.
  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(0, expr.getContext());
  }
}
887 }
888 
889 /// Simplify the expressions in `map` while making use of lower or upper bounds
890 /// of its operands. If `isMax` is true, the map is to be treated as a max of
891 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
892 /// d1) can be simplified to (8) if the operands are respectively lower bounded
893 /// by 2 and 0 (the second expression can't be lower than 8).
894 static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
895                                              ArrayRef<Value> operands,
896                                              bool isMax) {
897   // Can't simplify.
898   if (operands.empty())
899     return;
900 
901   // Get the upper or lower bound on an affine.for op IV using its range.
902   // Get the constant lower or upper bounds on the operands.
903   SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
904   constLowerBounds.reserve(operands.size());
905   constUpperBounds.reserve(operands.size());
906   for (Value operand : operands) {
907     constLowerBounds.push_back(getLowerBound(operand));
908     constUpperBounds.push_back(getUpperBound(operand));
909   }
910 
911   // We will compute the lower and upper bounds on each of the expressions
912   // Then, we will check (depending on max or min) as to whether a specific
913   // bound is redundant by checking if its highest (in case of max) and its
914   // lowest (in the case of min) value is already lower than (or higher than)
915   // the lower bound (or upper bound in the case of min) of another bound.
916   SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
917   lowerBounds.reserve(map.getNumResults());
918   upperBounds.reserve(map.getNumResults());
919   for (AffineExpr e : map.getResults()) {
920     if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
921       lowerBounds.push_back(constExpr.getValue());
922       upperBounds.push_back(constExpr.getValue());
923     } else {
924       lowerBounds.push_back(
925           getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
926                                 constLowerBounds, constUpperBounds,
927                                 /*isUpper=*/false));
928       upperBounds.push_back(
929           getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
930                                 constLowerBounds, constUpperBounds,
931                                 /*isUpper=*/true));
932     }
933   }
934 
935   // Collect expressions that are not redundant.
936   SmallVector<AffineExpr, 4> irredundantExprs;
937   for (auto exprEn : llvm::enumerate(map.getResults())) {
938     AffineExpr e = exprEn.value();
939     unsigned i = exprEn.index();
940     // Some expressions can be turned into constants.
941     if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
942       e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
943 
944     // Check if the expression is redundant.
945     if (isMax) {
946       if (!upperBounds[i]) {
947         irredundantExprs.push_back(e);
948         continue;
949       }
950       // If there exists another expression such that its lower bound is greater
951       // than this expression's upper bound, it's redundant.
952       if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
953             auto otherLowerBound = en.value();
954             unsigned pos = en.index();
955             if (pos == i || !otherLowerBound)
956               return false;
957             if (*otherLowerBound > *upperBounds[i])
958               return true;
959             if (*otherLowerBound < *upperBounds[i])
960               return false;
961             // Equality case. When both expressions are considered redundant, we
962             // don't want to get both of them. We keep the one that appears
963             // first.
964             if (upperBounds[pos] && lowerBounds[i] &&
965                 lowerBounds[i] == upperBounds[i] &&
966                 otherLowerBound == *upperBounds[pos] && i < pos)
967               return false;
968             return true;
969           }))
970         irredundantExprs.push_back(e);
971     } else {
972       if (!lowerBounds[i]) {
973         irredundantExprs.push_back(e);
974         continue;
975       }
976       // Likewise for the `min` case. Use the complement of the condition above.
977       if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
978             auto otherUpperBound = en.value();
979             unsigned pos = en.index();
980             if (pos == i || !otherUpperBound)
981               return false;
982             if (*otherUpperBound < *lowerBounds[i])
983               return true;
984             if (*otherUpperBound > *lowerBounds[i])
985               return false;
986             if (lowerBounds[pos] && upperBounds[i] &&
987                 lowerBounds[i] == upperBounds[i] &&
988                 otherUpperBound == lowerBounds[pos] && i < pos)
989               return false;
990             return true;
991           }))
992         irredundantExprs.push_back(e);
993     }
994   }
995 
996   // Create the map without the redundant expressions.
997   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
998                        map.getContext());
999 }
1000 
1001 /// Simplify the map while exploiting information on the values in `operands`.
1002 //  Use "unused attribute" marker to silence warning stemming from the inability
1003 //  to see through the template expansion.
1004 static void LLVM_ATTRIBUTE_UNUSED
1005 simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
1006   assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1007   SmallVector<AffineExpr> newResults;
1008   newResults.reserve(map.getNumResults());
1009   for (AffineExpr expr : map.getResults()) {
1010     simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
1011                             operands);
1012     newResults.push_back(expr);
1013   }
1014   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1015                        map.getContext());
1016 }
1017 
/// Replace all occurrences of the dim or symbol at flat position
/// `dimOrSymbolPosition` in `map` by the expression of the AffineApplyOp that
/// defines the corresponding operand, pulling in that op's operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[dimOrSymbolPosition]
/// is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[dimOrSymbolPosition - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. The replacement expression refers to the newly appended dims and
///      symbols, i.e. to positions past the pre-existing ones.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure, leaving everything untouched, when the entry is already
/// null or is not produced by an AffineApplyOp.
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  MLIRContext *ctx = map->getContext();
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  // NOTE: `v` refers into `dims`/`syms`; it must not be used after the
  // appends below, which may reallocate those vectors.
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(&composeMap, &composeOperands);
  // Shift the applied expression so its dims/symbols index the entries about
  // to be appended to `dims`/`syms`.
  AffineExpr replacementExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

  return success();
}
1071 }
1072 
1073 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1074 /// iteratively. Perform canonicalization of map and operands as well as
1075 /// AffineMap simplification. `map` and `operands` are mutated in place.
1076 static void composeAffineMapAndOperands(AffineMap *map,
1077                                         SmallVectorImpl<Value> *operands) {
1078   if (map->getNumResults() == 0) {
1079     canonicalizeMapAndOperands(map, operands);
1080     *map = simplifyAffineMap(*map);
1081     return;
1082   }
1083 
1084   MLIRContext *ctx = map->getContext();
1085   SmallVector<Value, 4> dims(operands->begin(),
1086                              operands->begin() + map->getNumDims());
1087   SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1088                              operands->end());
1089 
1090   // Iterate over dims and symbols coming from AffineApplyOp and replace until
1091   // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1092   // and `syms` can only increase by construction.
1093   // The implementation uses a `while` loop to support the case of symbols
1094   // that may be constructed from dims ;this may be overkill.
1095   while (true) {
1096     bool changed = false;
1097     for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1098       if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1099         break;
1100     if (!changed)
1101       break;
1102   }
1103 
1104   // Clear operands so we can fill them anew.
1105   operands->clear();
1106 
1107   // At this point we may have introduced null operands, prune them out before
1108   // canonicalizing map and operands.
1109   unsigned nDims = 0, nSyms = 0;
1110   SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1111   dimReplacements.reserve(dims.size());
1112   symReplacements.reserve(syms.size());
1113   for (auto *container : {&dims, &syms}) {
1114     bool isDim = (container == &dims);
1115     auto &repls = isDim ? dimReplacements : symReplacements;
1116     for (const auto &en : llvm::enumerate(*container)) {
1117       Value v = en.value();
1118       if (!v) {
1119         assert(isDim ? !map->isFunctionOfDim(en.index())
1120                      : !map->isFunctionOfSymbol(en.index()) &&
1121                            "map is function of unexpected expr@pos");
1122         repls.push_back(getAffineConstantExpr(0, ctx));
1123         continue;
1124       }
1125       repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1126                             : getAffineSymbolExpr(nSyms++, ctx));
1127       operands->push_back(v);
1128     }
1129   }
1130   *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1131                                     nSyms);
1132 
1133   // Canonicalize and simplify before returning.
1134   canonicalizeMapAndOperands(map, operands);
1135   *map = simplifyAffineMap(*map);
1136 }
1137 
1138 void mlir::affine::fullyComposeAffineMapAndOperands(
1139     AffineMap *map, SmallVectorImpl<Value> *operands) {
1140   while (llvm::any_of(*operands, [](Value v) {
1141     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1142   })) {
1143     composeAffineMapAndOperands(map, operands);
1144   }
1145 }
1146 
1147 AffineApplyOp
1148 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1149                                       ArrayRef<OpFoldResult> operands) {
1150   SmallVector<Value> valueOperands;
1151   map = foldAttributesIntoMap(b, map, operands, valueOperands);
1152   composeAffineMapAndOperands(&map, &valueOperands);
1153   assert(map);
1154   return b.create<AffineApplyOp>(loc, map, valueOperands);
1155 }
1156 
1157 AffineApplyOp
1158 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1159                                       ArrayRef<OpFoldResult> operands) {
1160   return makeComposedAffineApply(
1161       b, loc,
1162       AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
1163           .front(),
1164       operands);
1165 }
1166 
1167 /// Composes the given affine map with the given list of operands, pulling in
1168 /// the maps from any affine.apply operations that supply the operands.
1169 static void composeMultiResultAffineMap(AffineMap &map,
1170                                         SmallVectorImpl<Value> &operands) {
1171   // Compose and canonicalize each expression in the map individually because
1172   // composition only applies to single-result maps, collecting potentially
1173   // duplicate operands in a single list with shifted dimensions and symbols.
1174   SmallVector<Value> dims, symbols;
1175   SmallVector<AffineExpr> exprs;
1176   for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1177     SmallVector<Value> submapOperands(operands.begin(), operands.end());
1178     AffineMap submap = map.getSubMap({i});
1179     fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1180     canonicalizeMapAndOperands(&submap, &submapOperands);
1181     unsigned numNewDims = submap.getNumDims();
1182     submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1183     llvm::append_range(dims,
1184                        ArrayRef<Value>(submapOperands).take_front(numNewDims));
1185     llvm::append_range(symbols,
1186                        ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1187     exprs.push_back(submap.getResult(0));
1188   }
1189 
1190   // Canonicalize the map created from composed expressions to deduplicate the
1191   // dimension and symbol operands.
1192   operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1193   map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1194   canonicalizeMapAndOperands(&map, &operands);
1195 }
1196 
1197 OpFoldResult
1198 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1199                                             AffineMap map,
1200                                             ArrayRef<OpFoldResult> operands) {
1201   assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1202 
1203   // Create new builder without a listener, so that no notification is
1204   // triggered if the op is folded.
1205   // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1206   // workaround is no longer needed.
1207   OpBuilder newBuilder(b.getContext());
1208   newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1209 
1210   // Create op.
1211   AffineApplyOp applyOp =
1212       makeComposedAffineApply(newBuilder, loc, map, operands);
1213 
1214   // Get constant operands.
1215   SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1216   for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1217     matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1218 
1219   // Try to fold the operation.
1220   SmallVector<OpFoldResult> foldResults;
1221   if (failed(applyOp->fold(constOperands, foldResults)) ||
1222       foldResults.empty()) {
1223     if (OpBuilder::Listener *listener = b.getListener())
1224       listener->notifyOperationInserted(applyOp, /*previous=*/{});
1225     return applyOp.getResult();
1226   }
1227 
1228   applyOp->erase();
1229   assert(foldResults.size() == 1 && "expected 1 folded result");
1230   return foldResults.front();
1231 }
1232 
1233 OpFoldResult
1234 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1235                                             AffineExpr expr,
1236                                             ArrayRef<OpFoldResult> operands) {
1237   return makeComposedFoldedAffineApply(
1238       b, loc,
1239       AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
1240           .front(),
1241       operands);
1242 }
1243 
1244 SmallVector<OpFoldResult>
1245 mlir::affine::makeComposedFoldedMultiResultAffineApply(
1246     OpBuilder &b, Location loc, AffineMap map,
1247     ArrayRef<OpFoldResult> operands) {
1248   return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1249                              [&](unsigned i) {
1250                                return makeComposedFoldedAffineApply(
1251                                    b, loc, map.getSubMap({i}), operands);
1252                              });
1253 }
1254 
1255 template <typename OpTy>
1256 static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1257                                ArrayRef<OpFoldResult> operands) {
1258   SmallVector<Value> valueOperands;
1259   map = foldAttributesIntoMap(b, map, operands, valueOperands);
1260   composeMultiResultAffineMap(map, valueOperands);
1261   return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1262 }
1263 
1264 AffineMinOp
1265 mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
1266                                     ArrayRef<OpFoldResult> operands) {
1267   return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1268 }
1269 
1270 template <typename OpTy>
1271 static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
1272                                              AffineMap map,
1273                                              ArrayRef<OpFoldResult> operands) {
1274   // Create new builder without a listener, so that no notification is
1275   // triggered if the op is folded.
1276   // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1277   // workaround is no longer needed.
1278   OpBuilder newBuilder(b.getContext());
1279   newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1280 
1281   // Create op.
1282   auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1283 
1284   // Get constant operands.
1285   SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1286   for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1287     matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1288 
1289   // Try to fold the operation.
1290   SmallVector<OpFoldResult> foldResults;
1291   if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1292       foldResults.empty()) {
1293     if (OpBuilder::Listener *listener = b.getListener())
1294       listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1295     return minMaxOp.getResult();
1296   }
1297 
1298   minMaxOp->erase();
1299   assert(foldResults.size() == 1 && "expected 1 folded result");
1300   return foldResults.front();
1301 }
1302 
1303 OpFoldResult
1304 mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
1305                                           AffineMap map,
1306                                           ArrayRef<OpFoldResult> operands) {
1307   return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1308 }
1309 
1310 OpFoldResult
1311 mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
1312                                           AffineMap map,
1313                                           ArrayRef<OpFoldResult> operands) {
1314   return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1315 }
1316 
1317 // A symbol may appear as a dim in affine.apply operations. This function
1318 // canonicalizes dims that are valid symbols into actual symbols.
1319 template <class MapOrSet>
1320 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1321                                         SmallVectorImpl<Value> *operands) {
1322   if (!mapOrSet || operands->empty())
1323     return;
1324 
1325   assert(mapOrSet->getNumInputs() == operands->size() &&
1326          "map/set inputs must match number of operands");
1327 
1328   auto *context = mapOrSet->getContext();
1329   SmallVector<Value, 8> resultOperands;
1330   resultOperands.reserve(operands->size());
1331   SmallVector<Value, 8> remappedSymbols;
1332   remappedSymbols.reserve(operands->size());
1333   unsigned nextDim = 0;
1334   unsigned nextSym = 0;
1335   unsigned oldNumSyms = mapOrSet->getNumSymbols();
1336   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1337   for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1338     if (i < mapOrSet->getNumDims()) {
1339       if (isValidSymbol((*operands)[i])) {
1340         // This is a valid symbol that appears as a dim, canonicalize it.
1341         dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1342         remappedSymbols.push_back((*operands)[i]);
1343       } else {
1344         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1345         resultOperands.push_back((*operands)[i]);
1346       }
1347     } else {
1348       resultOperands.push_back((*operands)[i]);
1349     }
1350   }
1351 
1352   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1353   *operands = resultOperands;
1354   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1355                                               oldNumSyms + nextSym);
1356 
1357   assert(mapOrSet->getNumInputs() == operands->size() &&
1358          "map/set inputs must match number of operands");
1359 }
1360 
1361 // Works for either an affine map or an integer set.
1362 template <class MapOrSet>
1363 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1364                                             SmallVectorImpl<Value> *operands) {
1365   static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1366                 "Argument must be either of AffineMap or IntegerSet type");
1367 
1368   if (!mapOrSet || operands->empty())
1369     return;
1370 
1371   assert(mapOrSet->getNumInputs() == operands->size() &&
1372          "map/set inputs must match number of operands");
1373 
1374   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1375 
1376   // Check to see what dims are used.
1377   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1378   llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1379   mapOrSet->walkExprs([&](AffineExpr expr) {
1380     if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1381       usedDims[dimExpr.getPosition()] = true;
1382     else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1383       usedSyms[symExpr.getPosition()] = true;
1384   });
1385 
1386   auto *context = mapOrSet->getContext();
1387 
1388   SmallVector<Value, 8> resultOperands;
1389   resultOperands.reserve(operands->size());
1390 
1391   llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1392   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1393   unsigned nextDim = 0;
1394   for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1395     if (usedDims[i]) {
1396       // Remap dim positions for duplicate operands.
1397       auto it = seenDims.find((*operands)[i]);
1398       if (it == seenDims.end()) {
1399         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1400         resultOperands.push_back((*operands)[i]);
1401         seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1402       } else {
1403         dimRemapping[i] = it->second;
1404       }
1405     }
1406   }
1407   llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1408   SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1409   unsigned nextSym = 0;
1410   for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1411     if (!usedSyms[i])
1412       continue;
1413     // Handle constant operands (only needed for symbolic operands since
1414     // constant operands in dimensional positions would have already been
1415     // promoted to symbolic positions above).
1416     IntegerAttr operandCst;
1417     if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1418                      m_Constant(&operandCst))) {
1419       symRemapping[i] =
1420           getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1421       continue;
1422     }
1423     // Remap symbol positions for duplicate operands.
1424     auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1425     if (it == seenSymbols.end()) {
1426       symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1427       resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1428       seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1429                                         symRemapping[i]));
1430     } else {
1431       symRemapping[i] = it->second;
1432     }
1433   }
1434   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1435                                               nextDim, nextSym);
1436   *operands = resultOperands;
1437 }
1438 
/// Canonicalize an affine map together with its operand list, in place.
/// Thin public wrapper around the shared map/set implementation: it promotes
/// constant/symbol operands, drops unused dims and symbols, deduplicates
/// repeated operands, and folds constant symbol operands into the map.
/// `map` and `operands` are both rewritten.
void mlir::affine::canonicalizeMapAndOperands(
    AffineMap *map, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
1443 
/// Canonicalize an integer set together with its operand list, in place.
/// Same transformation as canonicalizeMapAndOperands, applied to an
/// IntegerSet: unused dims/symbols are dropped, duplicate operands merged,
/// and constant symbol operands folded into the set's constraints.
void mlir::affine::canonicalizeSetAndOperands(
    IntegerSet *set, SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
1448 
1449 namespace {
1450 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1451 /// maps that supply results into them.
1452 ///
1453 template <typename AffineOpTy>
1454 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1455   using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1456 
1457   /// Replace the affine op with another instance of it with the supplied
1458   /// map and mapOperands.
1459   void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1460                        AffineMap map, ArrayRef<Value> mapOperands) const;
1461 
1462   LogicalResult matchAndRewrite(AffineOpTy affineOp,
1463                                 PatternRewriter &rewriter) const override {
1464     static_assert(
1465         llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1466                         AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1467                         AffineVectorStoreOp, AffineVectorLoadOp>::value,
1468         "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1469         "expected");
1470     auto map = affineOp.getAffineMap();
1471     AffineMap oldMap = map;
1472     auto oldOperands = affineOp.getMapOperands();
1473     SmallVector<Value, 8> resultOperands(oldOperands);
1474     composeAffineMapAndOperands(&map, &resultOperands);
1475     canonicalizeMapAndOperands(&map, &resultOperands);
1476     simplifyMapWithOperands(map, resultOperands);
1477     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1478                                     resultOperands.begin()))
1479       return failure();
1480 
1481     replaceAffineOp(rewriter, affineOp, map, resultOperands);
1482     return success();
1483   }
1484 };
1485 
1486 // Specialize the template to account for the different build signatures for
1487 // affine load, store, and apply ops.
1488 template <>
1489 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1490     PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1491     ArrayRef<Value> mapOperands) const {
1492   rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1493                                             mapOperands);
1494 }
1495 template <>
1496 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1497     PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1498     ArrayRef<Value> mapOperands) const {
1499   rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1500       prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1501       prefetch.getLocalityHint(), prefetch.getIsDataCache());
1502 }
1503 template <>
1504 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1505     PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1506     ArrayRef<Value> mapOperands) const {
1507   rewriter.replaceOpWithNewOp<AffineStoreOp>(
1508       store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1509 }
1510 template <>
1511 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1512     PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1513     ArrayRef<Value> mapOperands) const {
1514   rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1515       vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1516       mapOperands);
1517 }
1518 template <>
1519 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1520     PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1521     ArrayRef<Value> mapOperands) const {
1522   rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1523       vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1524       mapOperands);
1525 }
1526 
1527 // Generic version for ops that don't have extra operands.
1528 template <typename AffineOpTy>
1529 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1530     PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1531     ArrayRef<Value> mapOperands) const {
1532   rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1533 }
1534 } // namespace
1535 
1536 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1537                                                 MLIRContext *context) {
1538   results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1539 }
1540 
1541 //===----------------------------------------------------------------------===//
1542 // AffineDmaStartOp
1543 //===----------------------------------------------------------------------===//
1544 
1545 // TODO: Check that map operands are loop IVs or symbols.
1546 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1547                              Value srcMemRef, AffineMap srcMap,
1548                              ValueRange srcIndices, Value destMemRef,
1549                              AffineMap dstMap, ValueRange destIndices,
1550                              Value tagMemRef, AffineMap tagMap,
1551                              ValueRange tagIndices, Value numElements,
1552                              Value stride, Value elementsPerStride) {
1553   result.addOperands(srcMemRef);
1554   result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1555   result.addOperands(srcIndices);
1556   result.addOperands(destMemRef);
1557   result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1558   result.addOperands(destIndices);
1559   result.addOperands(tagMemRef);
1560   result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1561   result.addOperands(tagIndices);
1562   result.addOperands(numElements);
1563   if (stride) {
1564     result.addOperands({stride, elementsPerStride});
1565   }
1566 }
1567 
1568 void AffineDmaStartOp::print(OpAsmPrinter &p) {
1569   p << " " << getSrcMemRef() << '[';
1570   p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1571   p << "], " << getDstMemRef() << '[';
1572   p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1573   p << "], " << getTagMemRef() << '[';
1574   p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1575   p << "], " << getNumElements();
1576   if (isStrided()) {
1577     p << ", " << getStride();
1578     p << ", " << getNumElementsPerStride();
1579   }
1580   p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1581     << getTagMemRefType();
1582 }
1583 
1584 // Parse AffineDmaStartOp.
1585 // Ex:
1586 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1587 //     %stride, %num_elt_per_stride
1588 //       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1589 //
1590 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1591                                     OperationState &result) {
1592   OpAsmParser::UnresolvedOperand srcMemRefInfo;
1593   AffineMapAttr srcMapAttr;
1594   SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1595   OpAsmParser::UnresolvedOperand dstMemRefInfo;
1596   AffineMapAttr dstMapAttr;
1597   SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1598   OpAsmParser::UnresolvedOperand tagMemRefInfo;
1599   AffineMapAttr tagMapAttr;
1600   SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1601   OpAsmParser::UnresolvedOperand numElementsInfo;
1602   SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1603 
1604   SmallVector<Type, 3> types;
1605   auto indexType = parser.getBuilder().getIndexType();
1606 
1607   // Parse and resolve the following list of operands:
1608   // *) dst memref followed by its affine maps operands (in square brackets).
1609   // *) src memref followed by its affine map operands (in square brackets).
1610   // *) tag memref followed by its affine map operands (in square brackets).
1611   // *) number of elements transferred by DMA operation.
1612   if (parser.parseOperand(srcMemRefInfo) ||
1613       parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1614                                     getSrcMapAttrStrName(),
1615                                     result.attributes) ||
1616       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1617       parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1618                                     getDstMapAttrStrName(),
1619                                     result.attributes) ||
1620       parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1621       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1622                                     getTagMapAttrStrName(),
1623                                     result.attributes) ||
1624       parser.parseComma() || parser.parseOperand(numElementsInfo))
1625     return failure();
1626 
1627   // Parse optional stride and elements per stride.
1628   if (parser.parseTrailingOperandList(strideInfo))
1629     return failure();
1630 
1631   if (!strideInfo.empty() && strideInfo.size() != 2) {
1632     return parser.emitError(parser.getNameLoc(),
1633                             "expected two stride related operands");
1634   }
1635   bool isStrided = strideInfo.size() == 2;
1636 
1637   if (parser.parseColonTypeList(types))
1638     return failure();
1639 
1640   if (types.size() != 3)
1641     return parser.emitError(parser.getNameLoc(), "expected three types");
1642 
1643   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1644       parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1645       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1646       parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1647       parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1648       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1649       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1650     return failure();
1651 
1652   if (isStrided) {
1653     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1654       return failure();
1655   }
1656 
1657   // Check that src/dst/tag operand counts match their map.numInputs.
1658   if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1659       dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1660       tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1661     return parser.emitError(parser.getNameLoc(),
1662                             "memref operand count not equal to map.numInputs");
1663   return success();
1664 }
1665 
1666 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1667   if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1668     return emitOpError("expected DMA source to be of memref type");
1669   if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1670     return emitOpError("expected DMA destination to be of memref type");
1671   if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1672     return emitOpError("expected DMA tag to be of memref type");
1673 
1674   unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1675                               getDstMap().getNumInputs() +
1676                               getTagMap().getNumInputs();
1677   if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1678       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1679     return emitOpError("incorrect number of operands");
1680   }
1681 
1682   Region *scope = getAffineScope(*this);
1683   for (auto idx : getSrcIndices()) {
1684     if (!idx.getType().isIndex())
1685       return emitOpError("src index to dma_start must have 'index' type");
1686     if (!isValidAffineIndexOperand(idx, scope))
1687       return emitOpError(
1688           "src index must be a valid dimension or symbol identifier");
1689   }
1690   for (auto idx : getDstIndices()) {
1691     if (!idx.getType().isIndex())
1692       return emitOpError("dst index to dma_start must have 'index' type");
1693     if (!isValidAffineIndexOperand(idx, scope))
1694       return emitOpError(
1695           "dst index must be a valid dimension or symbol identifier");
1696   }
1697   for (auto idx : getTagIndices()) {
1698     if (!idx.getType().isIndex())
1699       return emitOpError("tag index to dma_start must have 'index' type");
1700     if (!isValidAffineIndexOperand(idx, scope))
1701       return emitOpError(
1702           "tag index must be a valid dimension or symbol identifier");
1703   }
1704   return success();
1705 }
1706 
1707 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1708                                      SmallVectorImpl<OpFoldResult> &results) {
1709   /// dma_start(memrefcast) -> dma_start
1710   return memref::foldMemRefCast(*this);
1711 }
1712 
1713 void AffineDmaStartOp::getEffects(
1714     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1715         &effects) {
1716   effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1717                        SideEffects::DefaultResource::get());
1718   effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1719                        SideEffects::DefaultResource::get());
1720   effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1721                        SideEffects::DefaultResource::get());
1722 }
1723 
1724 //===----------------------------------------------------------------------===//
1725 // AffineDmaWaitOp
1726 //===----------------------------------------------------------------------===//
1727 
1728 // TODO: Check that map operands are loop IVs or symbols.
1729 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1730                             Value tagMemRef, AffineMap tagMap,
1731                             ValueRange tagIndices, Value numElements) {
1732   result.addOperands(tagMemRef);
1733   result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1734   result.addOperands(tagIndices);
1735   result.addOperands(numElements);
1736 }
1737 
1738 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1739   p << " " << getTagMemRef() << '[';
1740   SmallVector<Value, 2> operands(getTagIndices());
1741   p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1742   p << "], ";
1743   p.printOperand(getNumElements());
1744   p << " : " << getTagMemRef().getType();
1745 }
1746 
1747 // Parse AffineDmaWaitOp.
1748 // Eg:
1749 //   affine.dma_wait %tag[%index], %num_elements
1750 //     : memref<1 x i32, (d0) -> (d0), 4>
1751 //
1752 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1753                                    OperationState &result) {
1754   OpAsmParser::UnresolvedOperand tagMemRefInfo;
1755   AffineMapAttr tagMapAttr;
1756   SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1757   Type type;
1758   auto indexType = parser.getBuilder().getIndexType();
1759   OpAsmParser::UnresolvedOperand numElementsInfo;
1760 
1761   // Parse tag memref, its map operands, and dma size.
1762   if (parser.parseOperand(tagMemRefInfo) ||
1763       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1764                                     getTagMapAttrStrName(),
1765                                     result.attributes) ||
1766       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1767       parser.parseColonType(type) ||
1768       parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1769       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1770       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1771     return failure();
1772 
1773   if (!llvm::isa<MemRefType>(type))
1774     return parser.emitError(parser.getNameLoc(),
1775                             "expected tag to be of memref type");
1776 
1777   if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1778     return parser.emitError(parser.getNameLoc(),
1779                             "tag memref operand count != to map.numInputs");
1780   return success();
1781 }
1782 
1783 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1784   if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1785     return emitOpError("expected DMA tag to be of memref type");
1786   Region *scope = getAffineScope(*this);
1787   for (auto idx : getTagIndices()) {
1788     if (!idx.getType().isIndex())
1789       return emitOpError("index to dma_wait must have 'index' type");
1790     if (!isValidAffineIndexOperand(idx, scope))
1791       return emitOpError(
1792           "index must be a valid dimension or symbol identifier");
1793   }
1794   return success();
1795 }
1796 
1797 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1798                                     SmallVectorImpl<OpFoldResult> &results) {
1799   /// dma_wait(memrefcast) -> dma_wait
1800   return memref::foldMemRefCast(*this);
1801 }
1802 
1803 void AffineDmaWaitOp::getEffects(
1804     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1805         &effects) {
1806   effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1807                        SideEffects::DefaultResource::get());
1808 }
1809 
1810 //===----------------------------------------------------------------------===//
1811 // AffineForOp
1812 //===----------------------------------------------------------------------===//
1813 
1814 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1815 /// bodyBuilder are empty/null, we include default terminator op.
1816 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1817                         ValueRange lbOperands, AffineMap lbMap,
1818                         ValueRange ubOperands, AffineMap ubMap, int64_t step,
1819                         ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1820   assert(((!lbMap && lbOperands.empty()) ||
1821           lbOperands.size() == lbMap.getNumInputs()) &&
1822          "lower bound operand count does not match the affine map");
1823   assert(((!ubMap && ubOperands.empty()) ||
1824           ubOperands.size() == ubMap.getNumInputs()) &&
1825          "upper bound operand count does not match the affine map");
1826   assert(step > 0 && "step has to be a positive integer constant");
1827 
1828   OpBuilder::InsertionGuard guard(builder);
1829 
1830   // Set variadic segment sizes.
1831   result.addAttribute(
1832       getOperandSegmentSizeAttr(),
1833       builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1834                                     static_cast<int32_t>(ubOperands.size()),
1835                                     static_cast<int32_t>(iterArgs.size())}));
1836 
1837   for (Value val : iterArgs)
1838     result.addTypes(val.getType());
1839 
1840   // Add an attribute for the step.
1841   result.addAttribute(getStepAttrName(result.name),
1842                       builder.getIntegerAttr(builder.getIndexType(), step));
1843 
1844   // Add the lower bound.
1845   result.addAttribute(getLowerBoundMapAttrName(result.name),
1846                       AffineMapAttr::get(lbMap));
1847   result.addOperands(lbOperands);
1848 
1849   // Add the upper bound.
1850   result.addAttribute(getUpperBoundMapAttrName(result.name),
1851                       AffineMapAttr::get(ubMap));
1852   result.addOperands(ubOperands);
1853 
1854   result.addOperands(iterArgs);
1855   // Create a region and a block for the body.  The argument of the region is
1856   // the loop induction variable.
1857   Region *bodyRegion = result.addRegion();
1858   Block *bodyBlock = builder.createBlock(bodyRegion);
1859   Value inductionVar =
1860       bodyBlock->addArgument(builder.getIndexType(), result.location);
1861   for (Value val : iterArgs)
1862     bodyBlock->addArgument(val.getType(), val.getLoc());
1863 
1864   // Create the default terminator if the builder is not provided and if the
1865   // iteration arguments are not provided. Otherwise, leave this to the caller
1866   // because we don't know which values to return from the loop.
1867   if (iterArgs.empty() && !bodyBuilder) {
1868     ensureTerminator(*bodyRegion, builder, result.location);
1869   } else if (bodyBuilder) {
1870     OpBuilder::InsertionGuard guard(builder);
1871     builder.setInsertionPointToStart(bodyBlock);
1872     bodyBuilder(builder, result.location, inductionVar,
1873                 bodyBlock->getArguments().drop_front());
1874   }
1875 }
1876 
1877 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1878                         int64_t ub, int64_t step, ValueRange iterArgs,
1879                         BodyBuilderFn bodyBuilder) {
1880   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1881   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1882   return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1883                bodyBuilder);
1884 }
1885 
1886 LogicalResult AffineForOp::verifyRegions() {
1887   // Check that the body defines as single block argument for the induction
1888   // variable.
1889   auto *body = getBody();
1890   if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1891     return emitOpError("expected body to have a single index argument for the "
1892                        "induction variable");
1893 
1894   // Verify that the bound operands are valid dimension/symbols.
1895   /// Lower bound.
1896   if (getLowerBoundMap().getNumInputs() > 0)
1897     if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1898                                              getLowerBoundMap().getNumDims())))
1899       return failure();
1900   /// Upper bound.
1901   if (getUpperBoundMap().getNumInputs() > 0)
1902     if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1903                                              getUpperBoundMap().getNumDims())))
1904       return failure();
1905 
1906   unsigned opNumResults = getNumResults();
1907   if (opNumResults == 0)
1908     return success();
1909 
1910   // If ForOp defines values, check that the number and types of the defined
1911   // values match ForOp initial iter operands and backedge basic block
1912   // arguments.
1913   if (getNumIterOperands() != opNumResults)
1914     return emitOpError(
1915         "mismatch between the number of loop-carried values and results");
1916   if (getNumRegionIterArgs() != opNumResults)
1917     return emitOpError(
1918         "mismatch between the number of basic block args and results");
1919 
1920   return success();
1921 }
1922 
/// Parse one bound of an affine.for loop and record it in `result`.
///
/// `isLower` selects the lower bound (optional 'max' prefix) or the upper
/// bound (optional 'min' prefix). Three forms are accepted after the optional
/// prefix:
///   * a single SSA value: resolved as `index` and stored as the symbol
///     identity map;
///   * a full affine map followed by its dim and symbol operand lists;
///   * an integer literal: stored as a zero-input constant map.
/// The resulting map is stored under the lower/upper bound map attribute name,
/// and any bound operands are appended to `result.operands`.
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrStrName =
      isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
              : AffineForOp::getUpperBoundMapAttrName(result.name);

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  SMLoc attrLoc = p.getCurrentLocation();

  // Parses either an affine map or an integer literal. The parsed attribute is
  // added to result.attributes under boundAttrStrName; the integer branch
  // below replaces that entry.
  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
    // Record the operand count so the operands contributed by this bound can
    // be counted after parsing the dim/symbol lists.
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
    // Drop the raw integer attribute that parseAttribute just appended, and
    // store the equivalent zero-input constant map under the same name.
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrStrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}
2015 
2016 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2017   auto &builder = parser.getBuilder();
2018   OpAsmParser::Argument inductionVariable;
2019   inductionVariable.type = builder.getIndexType();
2020   // Parse the induction variable followed by '='.
2021   if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2022     return failure();
2023 
2024   // Parse loop bounds.
2025   int64_t numOperands = result.operands.size();
2026   if (parseBound(/*isLower=*/true, result, parser))
2027     return failure();
2028   int64_t numLbOperands = result.operands.size() - numOperands;
2029   if (parser.parseKeyword("to", " between bounds"))
2030     return failure();
2031   numOperands = result.operands.size();
2032   if (parseBound(/*isLower=*/false, result, parser))
2033     return failure();
2034   int64_t numUbOperands = result.operands.size() - numOperands;
2035 
2036   // Parse the optional loop step, we default to 1 if one is not present.
2037   if (parser.parseOptionalKeyword("step")) {
2038     result.addAttribute(
2039         getStepAttrName(result.name),
2040         builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2041   } else {
2042     SMLoc stepLoc = parser.getCurrentLocation();
2043     IntegerAttr stepAttr;
2044     if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2045                               getStepAttrName(result.name).data(),
2046                               result.attributes))
2047       return failure();
2048 
2049     if (stepAttr.getValue().isNegative())
2050       return parser.emitError(
2051           stepLoc,
2052           "expected step to be representable as a positive signed integer");
2053   }
2054 
2055   // Parse the optional initial iteration arguments.
2056   SmallVector<OpAsmParser::Argument, 4> regionArgs;
2057   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2058 
2059   // Induction variable.
2060   regionArgs.push_back(inductionVariable);
2061 
2062   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2063     // Parse assignment list and results type list.
2064     if (parser.parseAssignmentList(regionArgs, operands) ||
2065         parser.parseArrowTypeList(result.types))
2066       return failure();
2067     // Resolve input operands.
2068     for (auto argOperandType :
2069          llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2070       Type type = std::get<2>(argOperandType);
2071       std::get<0>(argOperandType).type = type;
2072       if (parser.resolveOperand(std::get<1>(argOperandType), type,
2073                                 result.operands))
2074         return failure();
2075     }
2076   }
2077 
2078   result.addAttribute(
2079       getOperandSegmentSizeAttr(),
2080       builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2081                                     static_cast<int32_t>(numUbOperands),
2082                                     static_cast<int32_t>(operands.size())}));
2083 
2084   // Parse the body region.
2085   Region *body = result.addRegion();
2086   if (regionArgs.size() != result.types.size() + 1)
2087     return parser.emitError(
2088         parser.getNameLoc(),
2089         "mismatch between the number of loop-carried values and results");
2090   if (parser.parseRegion(*body, regionArgs))
2091     return failure();
2092 
2093   AffineForOp::ensureTerminator(*body, builder, result.location);
2094 
2095   // Parse the optional attribute list.
2096   return parser.parseOptionalAttrDict(result.attributes);
2097 }
2098 
2099 static void printBound(AffineMapAttr boundMap,
2100                        Operation::operand_range boundOperands,
2101                        const char *prefix, OpAsmPrinter &p) {
2102   AffineMap map = boundMap.getValue();
2103 
2104   // Check if this bound should be printed using custom assembly form.
2105   // The decision to restrict printing custom assembly form to trivial cases
2106   // comes from the will to roundtrip MLIR binary -> text -> binary in a
2107   // lossless way.
2108   // Therefore, custom assembly form parsing and printing is only supported for
2109   // zero-operand constant maps and single symbol operand identity maps.
2110   if (map.getNumResults() == 1) {
2111     AffineExpr expr = map.getResult(0);
2112 
2113     // Print constant bound.
2114     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2115       if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2116         p << constExpr.getValue();
2117         return;
2118       }
2119     }
2120 
2121     // Print bound that consists of a single SSA symbol if the map is over a
2122     // single symbol.
2123     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2124       if (dyn_cast<AffineSymbolExpr>(expr)) {
2125         p.printOperand(*boundOperands.begin());
2126         return;
2127       }
2128     }
2129   } else {
2130     // Map has multiple results. Print 'min' or 'max' prefix.
2131     p << prefix << ' ';
2132   }
2133 
2134   // Print the map and its operands.
2135   p << boundMap;
2136   printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2137                         map.getNumDims(), p);
2138 }
2139 
2140 unsigned AffineForOp::getNumIterOperands() {
2141   AffineMap lbMap = getLowerBoundMapAttr().getValue();
2142   AffineMap ubMap = getUpperBoundMapAttr().getValue();
2143 
2144   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2145 }
2146 
2147 std::optional<MutableArrayRef<OpOperand>>
2148 AffineForOp::getYieldedValuesMutable() {
2149   return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2150 }
2151 
2152 void AffineForOp::print(OpAsmPrinter &p) {
2153   p << ' ';
2154   p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2155                         /*omitType=*/true);
2156   p << " = ";
2157   printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2158   p << " to ";
2159   printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2160 
2161   if (getStepAsInt() != 1)
2162     p << " step " << getStepAsInt();
2163 
2164   bool printBlockTerminators = false;
2165   if (getNumIterOperands() > 0) {
2166     p << " iter_args(";
2167     auto regionArgs = getRegionIterArgs();
2168     auto operands = getInits();
2169 
2170     llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2171       p << std::get<0>(it) << " = " << std::get<1>(it);
2172     });
2173     p << ") -> (" << getResultTypes() << ")";
2174     printBlockTerminators = true;
2175   }
2176 
2177   p << ' ';
2178   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2179                 printBlockTerminators);
2180   p.printOptionalAttrDict(
2181       (*this)->getAttrs(),
2182       /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2183                        getUpperBoundMapAttrName(getOperation()->getName()),
2184                        getStepAttrName(getOperation()->getName()),
2185                        getOperandSegmentSizeAttr()});
2186 }
2187 
2188 /// Fold the constant bounds of a loop.
2189 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2190   auto foldLowerOrUpperBound = [&forOp](bool lower) {
2191     // Check to see if each of the operands is the result of a constant.  If
2192     // so, get the value.  If not, ignore it.
2193     SmallVector<Attribute, 8> operandConstants;
2194     auto boundOperands =
2195         lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2196     for (auto operand : boundOperands) {
2197       Attribute operandCst;
2198       matchPattern(operand, m_Constant(&operandCst));
2199       operandConstants.push_back(operandCst);
2200     }
2201 
2202     AffineMap boundMap =
2203         lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2204     assert(boundMap.getNumResults() >= 1 &&
2205            "bound maps should have at least one result");
2206     SmallVector<Attribute, 4> foldedResults;
2207     if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2208       return failure();
2209 
2210     // Compute the max or min as applicable over the results.
2211     assert(!foldedResults.empty() && "bounds should have at least one result");
2212     auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2213     for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2214       auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2215       maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2216                        : llvm::APIntOps::smin(maxOrMin, foldedResult);
2217     }
2218     lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2219           : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2220     return success();
2221   };
2222 
2223   // Try to fold the lower bound.
2224   bool folded = false;
2225   if (!forOp.hasConstantLowerBound())
2226     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2227 
2228   // Try to fold the upper bound.
2229   if (!forOp.hasConstantUpperBound())
2230     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2231   return success(folded);
2232 }
2233 
/// Canonicalize the bounds of the given loop.
///
/// Each bound map and its operand list are passed through
/// composeAffineMapAndOperands, canonicalizeMapAndOperands,
/// simplifyMinOrMaxExprWithOperands (max for the lower bound, min for the
/// upper bound) and removeDuplicateExprs. Any bound whose map changed is
/// written back to `forOp` in place.
///
/// Returns success iff at least one bound map changed; failure otherwise.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Snapshots of the original maps, used below to detect changes.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  composeAffineMapAndOperands(&lbMap, &lbOperands);
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
  // NOTE(review): unlike the lower bound, the upper bound is min-simplified
  // here *before* it is composed and canonicalized below, so simplification
  // opportunities exposed by composition may be missed. Confirm whether this
  // ordering is intentional before changing it (it affects canonical forms).
  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
  lbMap = removeDuplicateExprs(lbMap);

  composeAffineMapAndOperands(&ubMap, &ubOperands);
  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  // (Operand lists are only written back together with a changed map, which
  // relies on this invariant.)
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}
2264 
2265 namespace {
2266 /// Returns constant trip count in trivial cases.
2267 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2268   int64_t step = forOp.getStepAsInt();
2269   if (!forOp.hasConstantBounds() || step <= 0)
2270     return std::nullopt;
2271   int64_t lb = forOp.getConstantLowerBound();
2272   int64_t ub = forOp.getConstantUpperBound();
2273   return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2274 }
2275 
2276 /// This is a pattern to fold trivially empty loop bodies.
2277 /// TODO: This should be moved into the folding hook.
2278 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2279   using OpRewritePattern<AffineForOp>::OpRewritePattern;
2280 
2281   LogicalResult matchAndRewrite(AffineForOp forOp,
2282                                 PatternRewriter &rewriter) const override {
2283     // Check that the body only contains a yield.
2284     if (!llvm::hasSingleElement(*forOp.getBody()))
2285       return failure();
2286     if (forOp.getNumResults() == 0)
2287       return success();
2288     std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2289     if (tripCount && *tripCount == 0) {
2290       // The initial values of the iteration arguments would be the op's
2291       // results.
2292       rewriter.replaceOp(forOp, forOp.getInits());
2293       return success();
2294     }
2295     SmallVector<Value, 4> replacements;
2296     auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2297     auto iterArgs = forOp.getRegionIterArgs();
2298     bool hasValDefinedOutsideLoop = false;
2299     bool iterArgsNotInOrder = false;
2300     for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2301       Value val = yieldOp.getOperand(i);
2302       auto *iterArgIt = llvm::find(iterArgs, val);
2303       if (iterArgIt == iterArgs.end()) {
2304         // `val` is defined outside of the loop.
2305         assert(forOp.isDefinedOutsideOfLoop(val) &&
2306                "must be defined outside of the loop");
2307         hasValDefinedOutsideLoop = true;
2308         replacements.push_back(val);
2309       } else {
2310         unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2311         if (pos != i)
2312           iterArgsNotInOrder = true;
2313         replacements.push_back(forOp.getInits()[pos]);
2314       }
2315     }
2316     // Bail out when the trip count is unknown and the loop returns any value
2317     // defined outside of the loop or any iterArg out of order.
2318     if (!tripCount.has_value() &&
2319         (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2320       return failure();
2321     // Bail out when the loop iterates more than once and it returns any iterArg
2322     // out of order.
2323     if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2324       return failure();
2325     rewriter.replaceOp(forOp, replacements);
2326     return success();
2327   }
2328 };
2329 } // namespace
2330 
2331 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2332                                               MLIRContext *context) {
2333   results.add<AffineForEmptyLoopFolder>(context);
2334 }
2335 
2336 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2337   assert((point.isParent() || point == getRegion()) && "invalid region point");
2338 
2339   // The initial operands map to the loop arguments after the induction
2340   // variable or are forwarded to the results when the trip count is zero.
2341   return getInits();
2342 }
2343 
2344 void AffineForOp::getSuccessorRegions(
2345     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2346   assert((point.isParent() || point == getRegion()) && "expected loop region");
2347   // The loop may typically branch back to its body or to the parent operation.
2348   // If the predecessor is the parent op and the trip count is known to be at
2349   // least one, branch into the body using the iterator arguments. And in cases
2350   // we know the trip count is zero, it can only branch back to its parent.
2351   std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2352   if (point.isParent() && tripCount.has_value()) {
2353     if (tripCount.value() > 0) {
2354       regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2355       return;
2356     }
2357     if (tripCount.value() == 0) {
2358       regions.push_back(RegionSuccessor(getResults()));
2359       return;
2360     }
2361   }
2362 
2363   // From the loop body, if the trip count is one, we can only branch back to
2364   // the parent.
2365   if (!point.isParent() && tripCount && *tripCount == 1) {
2366     regions.push_back(RegionSuccessor(getResults()));
2367     return;
2368   }
2369 
2370   // In all other cases, the loop may branch back to itself or the parent
2371   // operation.
2372   regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2373   regions.push_back(RegionSuccessor(getResults()));
2374 }
2375 
2376 /// Returns true if the affine.for has zero iterations in trivial cases.
2377 static bool hasTrivialZeroTripCount(AffineForOp op) {
2378   std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2379   return tripCount && *tripCount == 0;
2380 }
2381 
2382 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2383                                 SmallVectorImpl<OpFoldResult> &results) {
2384   bool folded = succeeded(foldLoopBounds(*this));
2385   folded |= succeeded(canonicalizeLoopBounds(*this));
2386   if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2387     // The initial values of the loop-carried variables (iter_args) are the
2388     // results of the op. But this must be avoided for an affine.for op that
2389     // does not return any results. Since ops that do not return results cannot
2390     // be folded away, we would enter an infinite loop of folds on the same
2391     // affine.for op.
2392     results.assign(getInits().begin(), getInits().end());
2393     folded = true;
2394   }
2395   return success(folded);
2396 }
2397 
2398 AffineBound AffineForOp::getLowerBound() {
2399   return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2400 }
2401 
2402 AffineBound AffineForOp::getUpperBound() {
2403   return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2404 }
2405 
/// Replaces the lower bound with `map` applied to `lbOperands`. The number of
/// operands must equal the map's input count, and the map must have at least
/// one result.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  // Replaces the lower-bound operand segment, then the map attribute.
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);
}

/// Replaces the upper bound with `map` applied to `ubOperands`. The number of
/// operands must equal the map's input count, and the map must have at least
/// one result.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  // Replaces the upper-bound operand segment, then the map attribute.
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);
}
2419 
2420 bool AffineForOp::hasConstantLowerBound() {
2421   return getLowerBoundMap().isSingleConstant();
2422 }
2423 
2424 bool AffineForOp::hasConstantUpperBound() {
2425   return getUpperBoundMap().isSingleConstant();
2426 }
2427 
2428 int64_t AffineForOp::getConstantLowerBound() {
2429   return getLowerBoundMap().getSingleConstantResult();
2430 }
2431 
2432 int64_t AffineForOp::getConstantUpperBound() {
2433   return getUpperBoundMap().getSingleConstantResult();
2434 }
2435 
2436 void AffineForOp::setConstantLowerBound(int64_t value) {
2437   setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2438 }
2439 
2440 void AffineForOp::setConstantUpperBound(int64_t value) {
2441   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2442 }
2443 
2444 AffineForOp::operand_range AffineForOp::getControlOperands() {
2445   return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2446                                getUpperBoundOperands().size()};
2447 }
2448 
2449 bool AffineForOp::matchingBoundOperandList() {
2450   auto lbMap = getLowerBoundMap();
2451   auto ubMap = getUpperBoundMap();
2452   if (lbMap.getNumDims() != ubMap.getNumDims() ||
2453       lbMap.getNumSymbols() != ubMap.getNumSymbols())
2454     return false;
2455 
2456   unsigned numOperands = lbMap.getNumInputs();
2457   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2458     // Compare Value 's.
2459     if (getOperand(i) != getOperand(numOperands + i))
2460       return false;
2461   }
2462   return true;
2463 }
2464 
2465 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2466 
2467 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2468   return SmallVector<Value>{getInductionVar()};
2469 }
2470 
2471 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2472   if (!hasConstantLowerBound())
2473     return std::nullopt;
2474   OpBuilder b(getContext());
2475   return SmallVector<OpFoldResult>{
2476       OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2477 }
2478 
2479 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2480   OpBuilder b(getContext());
2481   return SmallVector<OpFoldResult>{
2482       OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2483 }
2484 
2485 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2486   if (!hasConstantUpperBound())
2487     return {};
2488   OpBuilder b(getContext());
2489   return SmallVector<OpFoldResult>{
2490       OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2491 }
2492 
2493 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2494     RewriterBase &rewriter, ValueRange newInitOperands,
2495     bool replaceInitOperandUsesInLoop,
2496     const NewYieldValuesFn &newYieldValuesFn) {
2497   // Create a new loop before the existing one, with the extra operands.
2498   OpBuilder::InsertionGuard g(rewriter);
2499   rewriter.setInsertionPoint(getOperation());
2500   auto inits = llvm::to_vector(getInits());
2501   inits.append(newInitOperands.begin(), newInitOperands.end());
2502   AffineForOp newLoop = rewriter.create<AffineForOp>(
2503       getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2504       getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2505 
2506   // Generate the new yield values and append them to the scf.yield operation.
2507   auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2508   ArrayRef<BlockArgument> newIterArgs =
2509       newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2510   {
2511     OpBuilder::InsertionGuard g(rewriter);
2512     rewriter.setInsertionPoint(yieldOp);
2513     SmallVector<Value> newYieldedValues =
2514         newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2515     assert(newInitOperands.size() == newYieldedValues.size() &&
2516            "expected as many new yield values as new iter operands");
2517     rewriter.modifyOpInPlace(yieldOp, [&]() {
2518       yieldOp.getOperandsMutable().append(newYieldedValues);
2519     });
2520   }
2521 
2522   // Move the loop body to the new op.
2523   rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2524                        newLoop.getBody()->getArguments().take_front(
2525                            getBody()->getNumArguments()));
2526 
2527   if (replaceInitOperandUsesInLoop) {
2528     // Replace all uses of `newInitOperands` with the corresponding basic block
2529     // arguments.
2530     for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2531       rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2532                                  [&](OpOperand &use) {
2533                                    Operation *user = use.getOwner();
2534                                    return newLoop->isProperAncestor(user);
2535                                  });
2536     }
2537   }
2538 
2539   // Replace the old loop.
2540   rewriter.replaceOp(getOperation(),
2541                      newLoop->getResults().take_front(getNumResults()));
2542   return cast<LoopLikeOpInterface>(newLoop.getOperation());
2543 }
2544 
2545 Speculation::Speculatability AffineForOp::getSpeculatability() {
2546   // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2547   // Start and End.
2548   //
2549   // For Step != 1, the loop may not terminate.  We can add more smarts here if
2550   // needed.
2551   return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2552                              : Speculation::NotSpeculatable;
2553 }
2554 
2555 /// Returns true if the provided value is the induction variable of a
2556 /// AffineForOp.
2557 bool mlir::affine::isAffineForInductionVar(Value val) {
2558   return getForInductionVarOwner(val) != AffineForOp();
2559 }
2560 
2561 bool mlir::affine::isAffineParallelInductionVar(Value val) {
2562   return getAffineParallelInductionVarOwner(val) != nullptr;
2563 }
2564 
2565 bool mlir::affine::isAffineInductionVar(Value val) {
2566   return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2567 }
2568 
2569 AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2570   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2571   if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2572     return AffineForOp();
2573   if (auto forOp =
2574           ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2575     // Check to make sure `val` is the induction variable, not an iter_arg.
2576     return forOp.getInductionVar() == val ? forOp : AffineForOp();
2577   return AffineForOp();
2578 }
2579 
2580 AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2581   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2582   if (!ivArg || !ivArg.getOwner())
2583     return nullptr;
2584   Operation *containingOp = ivArg.getOwner()->getParentOp();
2585   auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2586   if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2587     return parallelOp;
2588   return nullptr;
2589 }
2590 
2591 /// Extracts the induction variables from a list of AffineForOps and returns
2592 /// them.
2593 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2594                                            SmallVectorImpl<Value> *ivs) {
2595   ivs->reserve(forInsts.size());
2596   for (auto forInst : forInsts)
2597     ivs->push_back(forInst.getInductionVar());
2598 }
2599 
2600 void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2601                                         SmallVectorImpl<mlir::Value> &ivs) {
2602   ivs.reserve(affineOps.size());
2603   for (Operation *op : affineOps) {
2604     // Add constraints from forOp's bounds.
2605     if (auto forOp = dyn_cast<AffineForOp>(op))
2606       ivs.push_back(forOp.getInductionVar());
2607     else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2608       for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2609         ivs.push_back(parallelOp.getBody()->getArgument(i));
2610   }
2611 }
2612 
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations.
///
/// `lbs`, `ubs` and `steps` must have the same length. Loop `i` is created from
/// `lbs[i]`, `ubs[i]` and `steps[i]` at the start of the body of loop `i - 1`
/// (loop 0 at the builder's current insertion point). Each loop body receives
/// an operand-less affine.yield terminator. `bodyBuilderFn`, if provided, is
/// called inside the innermost loop with all induction variables, outermost
/// first. With no loops, `bodyBuilderFn` is called at the current insertion
/// point with an empty range. The builder's insertion point is restored on
/// return.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    // It captures `i` and `ivs` by reference, which assumes loopCreatorFn
    // invokes it synchronously while building loop `i` (presumably exactly
    // once per loop; verify for custom creators).
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // The next loop is nested at the start of this loop's body.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2654 
2655 /// Creates an affine loop from the bounds known to be constants.
2656 static AffineForOp
2657 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2658                              int64_t ub, int64_t step,
2659                              AffineForOp::BodyBuilderFn bodyBuilderFn) {
2660   return builder.create<AffineForOp>(loc, lb, ub, step,
2661                                      /*iterArgs=*/std::nullopt, bodyBuilderFn);
2662 }
2663 
2664 /// Creates an affine loop from the bounds that may or may not be constants.
2665 static AffineForOp
2666 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2667                           int64_t step,
2668                           AffineForOp::BodyBuilderFn bodyBuilderFn) {
2669   std::optional<int64_t> lbConst = getConstantIntValue(lb);
2670   std::optional<int64_t> ubConst = getConstantIntValue(ub);
2671   if (lbConst && ubConst)
2672     return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2673                                         ubConst.value(), step, bodyBuilderFn);
2674   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2675                                      builder.getDimIdentityMap(), step,
2676                                      /*iterArgs=*/std::nullopt, bodyBuilderFn);
2677 }
2678 
2679 void mlir::affine::buildAffineLoopNest(
2680     OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2681     ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2682     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2683   buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2684                           buildAffineLoopFromConstants);
2685 }
2686 
2687 void mlir::affine::buildAffineLoopNest(
2688     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2689     ArrayRef<int64_t> steps,
2690     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2691   buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2692                           buildAffineLoopFromValues);
2693 }
2694 
2695 //===----------------------------------------------------------------------===//
2696 // AffineIfOp
2697 //===----------------------------------------------------------------------===//
2698 
namespace {
/// Erases the `else` block of a result-less affine.if when that block contains
/// nothing but its terminator (an operand-less yield, since there are no
/// results).
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when an else block exists, holds a single op, and the
    // affine.if produces no results.
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    // Erasing the block is an in-place modification of `ifOp`; notify the
    // rewriter around it.
    rewriter.startOpModification(ifOp);
    rewriter.eraseBlock(ifOp.getElseBlock());
    rewriter.finalizeOpModification(ifOp);
    return success();
  }
};

/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
///
/// "Trivially false" means the condition is the canonical empty integer set.
/// "Trivially true" means the condition is exactly one constraint, the
/// equality `0 == 0`.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // A single equality whose expression is the constant 0, i.e. `0 == 0`.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    // Grab the terminator before inlining, while it is still easy to reach.
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.inlineBlockBefore(blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(blockToMoveTerminator);
    return success();
  }
};
} // namespace
2771 
2772 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2773 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2774 void AffineIfOp::getSuccessorRegions(
2775     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2776   // If the predecessor is an AffineIfOp, then branching into both `then` and
2777   // `else` region is valid.
2778   if (point.isParent()) {
2779     regions.reserve(2);
2780     regions.push_back(
2781         RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2782     // If the "else" region is empty, branch bach into parent.
2783     if (getElseRegion().empty()) {
2784       regions.push_back(getResults());
2785     } else {
2786       regions.push_back(
2787           RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2788     }
2789     return;
2790   }
2791 
2792   // If the predecessor is the `else`/`then` region, then branching into parent
2793   // op is valid.
2794   regions.push_back(RegionSuccessor(getResults()));
2795 }
2796 
2797 LogicalResult AffineIfOp::verify() {
2798   // Verify that we have a condition attribute.
2799   // FIXME: This should be specified in the arguments list in ODS.
2800   auto conditionAttr =
2801       (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2802   if (!conditionAttr)
2803     return emitOpError("requires an integer set attribute named 'condition'");
2804 
2805   // Verify that there are enough operands for the condition.
2806   IntegerSet condition = conditionAttr.getValue();
2807   if (getNumOperands() != condition.getNumInputs())
2808     return emitOpError("operand count and condition integer set dimension and "
2809                        "symbol count must match");
2810 
2811   // Verify that the operands are valid dimension/symbols.
2812   if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2813                                            condition.getNumDims())))
2814     return failure();
2815 
2816   return success();
2817 }
2818 
2819 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2820   // Parse the condition attribute set.
2821   IntegerSetAttr conditionAttr;
2822   unsigned numDims;
2823   if (parser.parseAttribute(conditionAttr,
2824                             AffineIfOp::getConditionAttrStrName(),
2825                             result.attributes) ||
2826       parseDimAndSymbolList(parser, result.operands, numDims))
2827     return failure();
2828 
2829   // Verify the condition operands.
2830   auto set = conditionAttr.getValue();
2831   if (set.getNumDims() != numDims)
2832     return parser.emitError(
2833         parser.getNameLoc(),
2834         "dim operand count and integer set dim count must match");
2835   if (numDims + set.getNumSymbols() != result.operands.size())
2836     return parser.emitError(
2837         parser.getNameLoc(),
2838         "symbol operand count and integer set symbol count must match");
2839 
2840   if (parser.parseOptionalArrowTypeList(result.types))
2841     return failure();
2842 
2843   // Create the regions for 'then' and 'else'.  The latter must be created even
2844   // if it remains empty for the validity of the operation.
2845   result.regions.reserve(2);
2846   Region *thenRegion = result.addRegion();
2847   Region *elseRegion = result.addRegion();
2848 
2849   // Parse the 'then' region.
2850   if (parser.parseRegion(*thenRegion, {}, {}))
2851     return failure();
2852   AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2853                                result.location);
2854 
2855   // If we find an 'else' keyword then parse the 'else' region.
2856   if (!parser.parseOptionalKeyword("else")) {
2857     if (parser.parseRegion(*elseRegion, {}, {}))
2858       return failure();
2859     AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2860                                  result.location);
2861   }
2862 
2863   // Parse the optional attribute list.
2864   if (parser.parseOptionalAttrDict(result.attributes))
2865     return failure();
2866 
2867   return success();
2868 }
2869 
2870 void AffineIfOp::print(OpAsmPrinter &p) {
2871   auto conditionAttr =
2872       (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2873   p << " " << conditionAttr;
2874   printDimAndSymbolList(operand_begin(), operand_end(),
2875                         conditionAttr.getValue().getNumDims(), p);
2876   p.printOptionalArrowTypeList(getResultTypes());
2877   p << ' ';
2878   p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2879                 /*printBlockTerminators=*/getNumResults());
2880 
2881   // Print the 'else' regions if it has any blocks.
2882   auto &elseRegion = this->getElseRegion();
2883   if (!elseRegion.empty()) {
2884     p << " else ";
2885     p.printRegion(elseRegion,
2886                   /*printEntryBlockArgs=*/false,
2887                   /*printBlockTerminators=*/getNumResults());
2888   }
2889 
2890   // Print the attribute list.
2891   p.printOptionalAttrDict((*this)->getAttrs(),
2892                           /*elidedAttrs=*/getConditionAttrStrName());
2893 }
2894 
2895 IntegerSet AffineIfOp::getIntegerSet() {
2896   return (*this)
2897       ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2898       .getValue();
2899 }
2900 
2901 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2902   (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2903 }
2904 
2905 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2906   setIntegerSet(set);
2907   (*this)->setOperands(operands);
2908 }
2909 
2910 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2911                        TypeRange resultTypes, IntegerSet set, ValueRange args,
2912                        bool withElseRegion) {
2913   assert(resultTypes.empty() || withElseRegion);
2914   OpBuilder::InsertionGuard guard(builder);
2915 
2916   result.addTypes(resultTypes);
2917   result.addOperands(args);
2918   result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2919 
2920   Region *thenRegion = result.addRegion();
2921   builder.createBlock(thenRegion);
2922   if (resultTypes.empty())
2923     AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2924 
2925   Region *elseRegion = result.addRegion();
2926   if (withElseRegion) {
2927     builder.createBlock(elseRegion);
2928     if (resultTypes.empty())
2929       AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2930   }
2931 }
2932 
2933 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2934                        IntegerSet set, ValueRange args, bool withElseRegion) {
2935   AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2936                     withElseRegion);
2937 }
2938 
2939 /// Compose any affine.apply ops feeding into `operands` of the integer set
2940 /// `set` by composing the maps of such affine.apply ops with the integer
2941 /// set constraints.
2942 static void composeSetAndOperands(IntegerSet &set,
2943                                   SmallVectorImpl<Value> &operands) {
2944   // We will simply reuse the API of the map composition by viewing the LHSs of
2945   // the equalities and inequalities of `set` as the affine exprs of an affine
2946   // map. Convert to equivalent map, compose, and convert back to set.
2947   auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2948                             set.getConstraints(), set.getContext());
2949   // Check if any composition is possible.
2950   if (llvm::none_of(operands,
2951                     [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2952     return;
2953 
2954   composeAffineMapAndOperands(&map, &operands);
2955   set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2956                         set.getEqFlags());
2957 }
2958 
2959 /// Canonicalize an affine if op's conditional (integer set + operands).
2960 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2961   auto set = getIntegerSet();
2962   SmallVector<Value, 4> operands(getOperands());
2963   composeSetAndOperands(set, operands);
2964   canonicalizeSetAndOperands(&set, &operands);
2965 
2966   // Check if the canonicalization or composition led to any change.
2967   if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2968     return failure();
2969 
2970   setConditional(set, operands);
2971   return success();
2972 }
2973 
2974 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2975                                              MLIRContext *context) {
2976   results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2977 }
2978 
2979 //===----------------------------------------------------------------------===//
2980 // AffineLoadOp
2981 //===----------------------------------------------------------------------===//
2982 
2983 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2984                          AffineMap map, ValueRange operands) {
2985   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2986   result.addOperands(operands);
2987   if (map)
2988     result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2989   auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2990   result.types.push_back(memrefType.getElementType());
2991 }
2992 
2993 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2994                          Value memref, AffineMap map, ValueRange mapOperands) {
2995   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2996   result.addOperands(memref);
2997   result.addOperands(mapOperands);
2998   auto memrefType = llvm::cast<MemRefType>(memref.getType());
2999   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3000   result.types.push_back(memrefType.getElementType());
3001 }
3002 
3003 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3004                          Value memref, ValueRange indices) {
3005   auto memrefType = llvm::cast<MemRefType>(memref.getType());
3006   int64_t rank = memrefType.getRank();
3007   // Create identity map for memrefs with at least one dimension or () -> ()
3008   // for zero-dimensional memrefs.
3009   auto map =
3010       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3011   build(builder, result, memref, map, indices);
3012 }
3013 
3014 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3015   auto &builder = parser.getBuilder();
3016   auto indexTy = builder.getIndexType();
3017 
3018   MemRefType type;
3019   OpAsmParser::UnresolvedOperand memrefInfo;
3020   AffineMapAttr mapAttr;
3021   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3022   return failure(
3023       parser.parseOperand(memrefInfo) ||
3024       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3025                                     AffineLoadOp::getMapAttrStrName(),
3026                                     result.attributes) ||
3027       parser.parseOptionalAttrDict(result.attributes) ||
3028       parser.parseColonType(type) ||
3029       parser.resolveOperand(memrefInfo, type, result.operands) ||
3030       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3031       parser.addTypeToList(type.getElementType(), result.types));
3032 }
3033 
3034 void AffineLoadOp::print(OpAsmPrinter &p) {
3035   p << " " << getMemRef() << '[';
3036   if (AffineMapAttr mapAttr =
3037           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3038     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3039   p << ']';
3040   p.printOptionalAttrDict((*this)->getAttrs(),
3041                           /*elidedAttrs=*/{getMapAttrStrName()});
3042   p << " : " << getMemRefType();
3043 }
3044 
3045 /// Verify common indexing invariants of affine.load, affine.store,
3046 /// affine.vector_load and affine.vector_store.
3047 static LogicalResult
3048 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3049                        Operation::operand_range mapOperands,
3050                        MemRefType memrefType, unsigned numIndexOperands) {
3051   AffineMap map = mapAttr.getValue();
3052   if (map.getNumResults() != memrefType.getRank())
3053     return op->emitOpError("affine map num results must equal memref rank");
3054   if (map.getNumInputs() != numIndexOperands)
3055     return op->emitOpError("expects as many subscripts as affine map inputs");
3056 
3057   Region *scope = getAffineScope(op);
3058   for (auto idx : mapOperands) {
3059     if (!idx.getType().isIndex())
3060       return op->emitOpError("index to load must have 'index' type");
3061     if (!isValidAffineIndexOperand(idx, scope))
3062       return op->emitOpError(
3063           "index must be a valid dimension or symbol identifier");
3064   }
3065 
3066   return success();
3067 }
3068 
3069 LogicalResult AffineLoadOp::verify() {
3070   auto memrefType = getMemRefType();
3071   if (getType() != memrefType.getElementType())
3072     return emitOpError("result type must match element type of memref");
3073 
3074   if (failed(verifyMemoryOpIndexing(
3075           getOperation(),
3076           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3077           getMapOperands(), memrefType,
3078           /*numIndexOperands=*/getNumOperands() - 1)))
3079     return failure();
3080 
3081   return success();
3082 }
3083 
3084 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3085                                                MLIRContext *context) {
3086   results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3087 }
3088 
3089 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3090   /// load(memrefcast) -> load
3091   if (succeeded(memref::foldMemRefCast(*this)))
3092     return getResult();
3093 
3094   // Fold load from a global constant memref.
3095   auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3096   if (!getGlobalOp)
3097     return {};
3098   // Get to the memref.global defining the symbol.
3099   auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3100   if (!symbolTableOp)
3101     return {};
3102   auto global = dyn_cast_or_null<memref::GlobalOp>(
3103       SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3104   if (!global)
3105     return {};
3106 
3107   // Check if the global memref is a constant.
3108   auto cstAttr =
3109       llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3110   if (!cstAttr)
3111     return {};
3112   // If it's a splat constant, we can fold irrespective of indices.
3113   if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3114     return splatAttr.getSplatValue<Attribute>();
3115   // Otherwise, we can fold only if we know the indices.
3116   if (!getAffineMap().isConstant())
3117     return {};
3118   auto indices = llvm::to_vector<4>(
3119       llvm::map_range(getAffineMap().getConstantResults(),
3120                       [](int64_t v) -> uint64_t { return v; }));
3121   return cstAttr.getValues<Attribute>()[indices];
3122 }
3123 
3124 //===----------------------------------------------------------------------===//
3125 // AffineStoreOp
3126 //===----------------------------------------------------------------------===//
3127 
3128 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3129                           Value valueToStore, Value memref, AffineMap map,
3130                           ValueRange mapOperands) {
3131   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3132   result.addOperands(valueToStore);
3133   result.addOperands(memref);
3134   result.addOperands(mapOperands);
3135   result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3136 }
3137 
3138 // Use identity map.
3139 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3140                           Value valueToStore, Value memref,
3141                           ValueRange indices) {
3142   auto memrefType = llvm::cast<MemRefType>(memref.getType());
3143   int64_t rank = memrefType.getRank();
3144   // Create identity map for memrefs with at least one dimension or () -> ()
3145   // for zero-dimensional memrefs.
3146   auto map =
3147       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3148   build(builder, result, valueToStore, memref, map, indices);
3149 }
3150 
3151 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3152   auto indexTy = parser.getBuilder().getIndexType();
3153 
3154   MemRefType type;
3155   OpAsmParser::UnresolvedOperand storeValueInfo;
3156   OpAsmParser::UnresolvedOperand memrefInfo;
3157   AffineMapAttr mapAttr;
3158   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3159   return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3160                  parser.parseOperand(memrefInfo) ||
3161                  parser.parseAffineMapOfSSAIds(
3162                      mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3163                      result.attributes) ||
3164                  parser.parseOptionalAttrDict(result.attributes) ||
3165                  parser.parseColonType(type) ||
3166                  parser.resolveOperand(storeValueInfo, type.getElementType(),
3167                                        result.operands) ||
3168                  parser.resolveOperand(memrefInfo, type, result.operands) ||
3169                  parser.resolveOperands(mapOperands, indexTy, result.operands));
3170 }
3171 
3172 void AffineStoreOp::print(OpAsmPrinter &p) {
3173   p << " " << getValueToStore();
3174   p << ", " << getMemRef() << '[';
3175   if (AffineMapAttr mapAttr =
3176           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3177     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3178   p << ']';
3179   p.printOptionalAttrDict((*this)->getAttrs(),
3180                           /*elidedAttrs=*/{getMapAttrStrName()});
3181   p << " : " << getMemRefType();
3182 }
3183 
3184 LogicalResult AffineStoreOp::verify() {
3185   // The value to store must have the same type as memref element type.
3186   auto memrefType = getMemRefType();
3187   if (getValueToStore().getType() != memrefType.getElementType())
3188     return emitOpError(
3189         "value to store must have the same type as memref element type");
3190 
3191   if (failed(verifyMemoryOpIndexing(
3192           getOperation(),
3193           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3194           getMapOperands(), memrefType,
3195           /*numIndexOperands=*/getNumOperands() - 2)))
3196     return failure();
3197 
3198   return success();
3199 }
3200 
3201 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3202                                                 MLIRContext *context) {
3203   results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3204 }
3205 
3206 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3207                                   SmallVectorImpl<OpFoldResult> &results) {
3208   /// store(memrefcast) -> store
3209   return memref::foldMemRefCast(*this, getValueToStore());
3210 }
3211 
3212 //===----------------------------------------------------------------------===//
3213 // AffineMinMaxOpBase
3214 //===----------------------------------------------------------------------===//
3215 
3216 template <typename T>
3217 static LogicalResult verifyAffineMinMaxOp(T op) {
3218   // Verify that operand count matches affine map dimension and symbol count.
3219   if (op.getNumOperands() !=
3220       op.getMap().getNumDims() + op.getMap().getNumSymbols())
3221     return op.emitOpError(
3222         "operand count and affine map dimension and symbol count must match");
3223 
3224   if (op.getMap().getNumResults() == 0)
3225     return op.emitOpError("affine map expect at least one result");
3226   return success();
3227 }
3228 
3229 template <typename T>
3230 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3231   p << ' ' << op->getAttr(T::getMapAttrStrName());
3232   auto operands = op.getOperands();
3233   unsigned numDims = op.getMap().getNumDims();
3234   p << '(' << operands.take_front(numDims) << ')';
3235 
3236   if (operands.size() != numDims)
3237     p << '[' << operands.drop_front(numDims) << ']';
3238   p.printOptionalAttrDict(op->getAttrs(),
3239                           /*elidedAttrs=*/{T::getMapAttrStrName()});
3240 }
3241 
3242 template <typename T>
3243 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3244                                        OperationState &result) {
3245   auto &builder = parser.getBuilder();
3246   auto indexType = builder.getIndexType();
3247   SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
3248   SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
3249   AffineMapAttr mapAttr;
3250   return failure(
3251       parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3252                             result.attributes) ||
3253       parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3254       parser.parseOperandList(symInfos,
3255                               OpAsmParser::Delimiter::OptionalSquare) ||
3256       parser.parseOptionalAttrDict(result.attributes) ||
3257       parser.resolveOperands(dimInfos, indexType, result.operands) ||
3258       parser.resolveOperands(symInfos, indexType, result.operands) ||
3259       parser.addTypeToList(indexType, result.types));
3260 }
3261 
3262 /// Fold an affine min or max operation with the given operands. The operand
3263 /// list may contain nulls, which are interpreted as the operand not being a
3264 /// constant.
3265 template <typename T>
3266 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3267   static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3268                 "expected affine min or max op");
3269 
3270   // Fold the affine map.
3271   // TODO: Fold more cases:
3272   // min(some_affine, some_affine + constant, ...), etc.
3273   SmallVector<int64_t, 2> results;
3274   auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3275 
3276   if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3277     return op.getOperand(0);
3278 
3279   // If some of the map results are not constant, try changing the map in-place.
3280   if (results.empty()) {
3281     // If the map is the same, report that folding did not happen.
3282     if (foldedMap == op.getMap())
3283       return {};
3284     op->setAttr("map", AffineMapAttr::get(foldedMap));
3285     return op.getResult();
3286   }
3287 
3288   // Otherwise, completely fold the op into a constant.
3289   auto resultIt = std::is_same<T, AffineMinOp>::value
3290                       ? llvm::min_element(results)
3291                       : llvm::max_element(results);
3292   if (resultIt == results.end())
3293     return {};
3294   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3295 }
3296 
3297 /// Remove duplicated expressions in affine min/max ops.
3298 template <typename T>
3299 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3300   using OpRewritePattern<T>::OpRewritePattern;
3301 
3302   LogicalResult matchAndRewrite(T affineOp,
3303                                 PatternRewriter &rewriter) const override {
3304     AffineMap oldMap = affineOp.getAffineMap();
3305 
3306     SmallVector<AffineExpr, 4> newExprs;
3307     for (AffineExpr expr : oldMap.getResults()) {
3308       // This is a linear scan over newExprs, but it should be fine given that
3309       // we typically just have a few expressions per op.
3310       if (!llvm::is_contained(newExprs, expr))
3311         newExprs.push_back(expr);
3312     }
3313 
3314     if (newExprs.size() == oldMap.getNumResults())
3315       return failure();
3316 
3317     auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3318                                  newExprs, rewriter.getContext());
3319     rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3320 
3321     return success();
3322   }
3323 };
3324 
/// Merge an affine min/max op into its consumer if the consumer is also an
/// affine min/max op of the same kind (min into min, max into max).
///
/// This pattern requires the producer affine min/max op is bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map. Such an expression is dropped from the consumer map and
/// replaced by all of the producer's result expressions, with the producer's
/// dims/symbols appended after the consumer's own.
///
/// For example, a pattern like the following:
///
///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
///   %1 = affine.min affine_map<
///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
///
/// NOTE: the rewrite itself keeps all of the consumer's original dims/symbols
/// as operands, even when they become unused (e.g. `%0` above stays as `d0`).
/// The compact form shown is presumably reached after other canonicalizations
/// (such as SimplifyAffineOp) drop unused operands.
template <typename T>
struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operands are laid out as (dims..., symbols...).
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    // The consumer's operands are always kept; producer operands get appended.
    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If so it can be merged into this affine op.
    // NOTE(review): a producer referenced by several standalone expressions is
    // collected once per reference, duplicating its expressions and operands;
    // DeduplicateAffineMinMaxExpressions can later remove repeated results.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    if (producerOps.empty())
      return failure();

    // Running totals of dims/symbols in the merged map; producer identifiers
    // are numbered after these.
    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // For expressions we need to shift to avoid overlap: producer dim/symbol
      // positions are offset by the number already in use.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    // Operand order must match the map: all dims, then all symbols.
    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
3420 
3421 /// Canonicalize the result expression order of an affine map and return success
3422 /// if the order changed.
3423 ///
3424 /// The function flattens the map's affine expressions to coefficient arrays and
3425 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3426 /// for every dimension/symbol and a constant term. The canonicalization fails
3427 /// if a result expression is not pure or if the flattening requires local
3428 /// variables that, unlike dimensions and symbols, have no global order.
3429 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3430   SmallVector<SmallVector<int64_t>> flattenedExprs;
3431   for (const AffineExpr &resultExpr : map.getResults()) {
3432     // Fail if the expression is not pure.
3433     if (!resultExpr.isPureAffine())
3434       return failure();
3435 
3436     SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3437     auto flattenResult = flattener.walkPostOrder(resultExpr);
3438     if (failed(flattenResult))
3439       return failure();
3440 
3441     // Fail if the flattened expression has local variables.
3442     if (flattener.operandExprStack.back().size() !=
3443         map.getNumDims() + map.getNumSymbols() + 1)
3444       return failure();
3445 
3446     flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3447                                 flattener.operandExprStack.back().end());
3448   }
3449 
3450   // Fail if sorting is not necessary.
3451   if (llvm::is_sorted(flattenedExprs))
3452     return failure();
3453 
3454   // Reorder the result expressions according to their flattened form.
3455   SmallVector<unsigned> resultPermutation =
3456       llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3457   llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3458     return flattenedExprs[lhs] < flattenedExprs[rhs];
3459   });
3460   SmallVector<AffineExpr> newExprs;
3461   for (unsigned idx : resultPermutation)
3462     newExprs.push_back(map.getResult(idx));
3463 
3464   map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3465                        map.getContext());
3466   return success();
3467 }
3468 
3469 /// Canonicalize the affine map result expression order of an affine min/max
3470 /// operation.
3471 ///
3472 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3473 /// expressions and replaces the operation if the order changed.
3474 ///
3475 /// For example, the following operation:
3476 ///
3477 ///   %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3478 ///
3479 /// Turns into:
3480 ///
3481 ///   %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3482 template <typename T>
3483 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3484   using OpRewritePattern<T>::OpRewritePattern;
3485 
3486   LogicalResult matchAndRewrite(T affineOp,
3487                                 PatternRewriter &rewriter) const override {
3488     AffineMap map = affineOp.getAffineMap();
3489     if (failed(canonicalizeMapExprAndTermOrder(map)))
3490       return failure();
3491     rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3492     return success();
3493   }
3494 };
3495 
3496 template <typename T>
3497 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3498   using OpRewritePattern<T>::OpRewritePattern;
3499 
3500   LogicalResult matchAndRewrite(T affineOp,
3501                                 PatternRewriter &rewriter) const override {
3502     if (affineOp.getMap().getNumResults() != 1)
3503       return failure();
3504     rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3505                                                affineOp.getOperands());
3506     return success();
3507   }
3508 };
3509 
3510 //===----------------------------------------------------------------------===//
3511 // AffineMinOp
3512 //===----------------------------------------------------------------------===//
3513 //
//   %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3515 //
3516 
3517 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3518   return foldMinMaxOp(*this, adaptor.getOperands());
3519 }
3520 
3521 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3522                                               MLIRContext *context) {
3523   patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
3524                DeduplicateAffineMinMaxExpressions<AffineMinOp>,
3525                MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3526                CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
3527       context);
3528 }
3529 
3530 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3531 
3532 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3533   return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3534 }
3535 
3536 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3537 
3538 //===----------------------------------------------------------------------===//
3539 // AffineMaxOp
3540 //===----------------------------------------------------------------------===//
3541 //
3542 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3543 //
3544 
3545 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3546   return foldMinMaxOp(*this, adaptor.getOperands());
3547 }
3548 
3549 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3550                                               MLIRContext *context) {
3551   patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
3552                DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
3553                MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3554                CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
3555       context);
3556 }
3557 
3558 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3559 
3560 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3561   return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3562 }
3563 
3564 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3565 
3566 //===----------------------------------------------------------------------===//
3567 // AffinePrefetchOp
3568 //===----------------------------------------------------------------------===//
3569 
3570 //
3571 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3572 //
3573 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3574                                     OperationState &result) {
3575   auto &builder = parser.getBuilder();
3576   auto indexTy = builder.getIndexType();
3577 
3578   MemRefType type;
3579   OpAsmParser::UnresolvedOperand memrefInfo;
3580   IntegerAttr hintInfo;
3581   auto i32Type = parser.getBuilder().getIntegerType(32);
3582   StringRef readOrWrite, cacheType;
3583 
3584   AffineMapAttr mapAttr;
3585   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3586   if (parser.parseOperand(memrefInfo) ||
3587       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3588                                     AffinePrefetchOp::getMapAttrStrName(),
3589                                     result.attributes) ||
3590       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3591       parser.parseComma() || parser.parseKeyword("locality") ||
3592       parser.parseLess() ||
3593       parser.parseAttribute(hintInfo, i32Type,
3594                             AffinePrefetchOp::getLocalityHintAttrStrName(),
3595                             result.attributes) ||
3596       parser.parseGreater() || parser.parseComma() ||
3597       parser.parseKeyword(&cacheType) ||
3598       parser.parseOptionalAttrDict(result.attributes) ||
3599       parser.parseColonType(type) ||
3600       parser.resolveOperand(memrefInfo, type, result.operands) ||
3601       parser.resolveOperands(mapOperands, indexTy, result.operands))
3602     return failure();
3603 
3604   if (readOrWrite != "read" && readOrWrite != "write")
3605     return parser.emitError(parser.getNameLoc(),
3606                             "rw specifier has to be 'read' or 'write'");
3607   result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3608                       parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3609 
3610   if (cacheType != "data" && cacheType != "instr")
3611     return parser.emitError(parser.getNameLoc(),
3612                             "cache type has to be 'data' or 'instr'");
3613 
3614   result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3615                       parser.getBuilder().getBoolAttr(cacheType == "data"));
3616 
3617   return success();
3618 }
3619 
3620 void AffinePrefetchOp::print(OpAsmPrinter &p) {
3621   p << " " << getMemref() << '[';
3622   AffineMapAttr mapAttr =
3623       (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3624   if (mapAttr)
3625     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3626   p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3627     << "locality<" << getLocalityHint() << ">, "
3628     << (getIsDataCache() ? "data" : "instr");
3629   p.printOptionalAttrDict(
3630       (*this)->getAttrs(),
3631       /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3632                        getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3633   p << " : " << getMemRefType();
3634 }
3635 
3636 LogicalResult AffinePrefetchOp::verify() {
3637   auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3638   if (mapAttr) {
3639     AffineMap map = mapAttr.getValue();
3640     if (map.getNumResults() != getMemRefType().getRank())
3641       return emitOpError("affine.prefetch affine map num results must equal"
3642                          " memref rank");
3643     if (map.getNumInputs() + 1 != getNumOperands())
3644       return emitOpError("too few operands");
3645   } else {
3646     if (getNumOperands() != 1)
3647       return emitOpError("too few operands");
3648   }
3649 
3650   Region *scope = getAffineScope(*this);
3651   for (auto idx : getMapOperands()) {
3652     if (!isValidAffineIndexOperand(idx, scope))
3653       return emitOpError(
3654           "index must be a valid dimension or symbol identifier");
3655   }
3656   return success();
3657 }
3658 
3659 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3660                                                    MLIRContext *context) {
3661   // prefetch(memrefcast) -> prefetch
3662   results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3663 }
3664 
3665 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3666                                      SmallVectorImpl<OpFoldResult> &results) {
3667   /// prefetch(memrefcast) -> prefetch
3668   return memref::foldMemRefCast(*this);
3669 }
3670 
3671 //===----------------------------------------------------------------------===//
3672 // AffineParallelOp
3673 //===----------------------------------------------------------------------===//
3674 
3675 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3676                              TypeRange resultTypes,
3677                              ArrayRef<arith::AtomicRMWKind> reductions,
3678                              ArrayRef<int64_t> ranges) {
3679   SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3680   auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3681     return builder.getConstantAffineMap(value);
3682   }));
3683   SmallVector<int64_t> steps(ranges.size(), 1);
3684   build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3685         /*ubArgs=*/{}, steps);
3686 }
3687 
3688 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3689                              TypeRange resultTypes,
3690                              ArrayRef<arith::AtomicRMWKind> reductions,
3691                              ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3692                              ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3693                              ArrayRef<int64_t> steps) {
3694   assert(llvm::all_of(lbMaps,
3695                       [lbMaps](AffineMap m) {
3696                         return m.getNumDims() == lbMaps[0].getNumDims() &&
3697                                m.getNumSymbols() == lbMaps[0].getNumSymbols();
3698                       }) &&
3699          "expected all lower bounds maps to have the same number of dimensions "
3700          "and symbols");
3701   assert(llvm::all_of(ubMaps,
3702                       [ubMaps](AffineMap m) {
3703                         return m.getNumDims() == ubMaps[0].getNumDims() &&
3704                                m.getNumSymbols() == ubMaps[0].getNumSymbols();
3705                       }) &&
3706          "expected all upper bounds maps to have the same number of dimensions "
3707          "and symbols");
3708   assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3709          "expected lower bound maps to have as many inputs as lower bound "
3710          "operands");
3711   assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3712          "expected upper bound maps to have as many inputs as upper bound "
3713          "operands");
3714 
3715   OpBuilder::InsertionGuard guard(builder);
3716   result.addTypes(resultTypes);
3717 
3718   // Convert the reductions to integer attributes.
3719   SmallVector<Attribute, 4> reductionAttrs;
3720   for (arith::AtomicRMWKind reduction : reductions)
3721     reductionAttrs.push_back(
3722         builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3723   result.addAttribute(getReductionsAttrStrName(),
3724                       builder.getArrayAttr(reductionAttrs));
3725 
3726   // Concatenates maps defined in the same input space (same dimensions and
3727   // symbols), assumes there is at least one map.
3728   auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3729                                         SmallVectorImpl<int32_t> &groups) {
3730     if (maps.empty())
3731       return AffineMap::get(builder.getContext());
3732     SmallVector<AffineExpr> exprs;
3733     groups.reserve(groups.size() + maps.size());
3734     exprs.reserve(maps.size());
3735     for (AffineMap m : maps) {
3736       llvm::append_range(exprs, m.getResults());
3737       groups.push_back(m.getNumResults());
3738     }
3739     return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3740                           maps[0].getContext());
3741   };
3742 
3743   // Set up the bounds.
3744   SmallVector<int32_t> lbGroups, ubGroups;
3745   AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3746   AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3747   result.addAttribute(getLowerBoundsMapAttrStrName(),
3748                       AffineMapAttr::get(lbMap));
3749   result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3750                       builder.getI32TensorAttr(lbGroups));
3751   result.addAttribute(getUpperBoundsMapAttrStrName(),
3752                       AffineMapAttr::get(ubMap));
3753   result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3754                       builder.getI32TensorAttr(ubGroups));
3755   result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3756   result.addOperands(lbArgs);
3757   result.addOperands(ubArgs);
3758 
3759   // Create a region and a block for the body.
3760   auto *bodyRegion = result.addRegion();
3761   Block *body = builder.createBlock(bodyRegion);
3762 
3763   // Add all the block arguments.
3764   for (unsigned i = 0, e = steps.size(); i < e; ++i)
3765     body->addArgument(IndexType::get(builder.getContext()), result.location);
3766   if (resultTypes.empty())
3767     ensureTerminator(*bodyRegion, builder, result.location);
3768 }
3769 
3770 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3771   return {&getRegion()};
3772 }
3773 
3774 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3775 
3776 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3777   return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3778 }
3779 
3780 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3781   return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3782 }
3783 
3784 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3785   auto values = getLowerBoundsGroups().getValues<int32_t>();
3786   unsigned start = 0;
3787   for (unsigned i = 0; i < pos; ++i)
3788     start += values[i];
3789   return getLowerBoundsMap().getSliceMap(start, values[pos]);
3790 }
3791 
3792 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3793   auto values = getUpperBoundsGroups().getValues<int32_t>();
3794   unsigned start = 0;
3795   for (unsigned i = 0; i < pos; ++i)
3796     start += values[i];
3797   return getUpperBoundsMap().getSliceMap(start, values[pos]);
3798 }
3799 
3800 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3801   return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3802 }
3803 
3804 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3805   return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3806 }
3807 
3808 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3809   if (hasMinMaxBounds())
3810     return std::nullopt;
3811 
3812   // Try to convert all the ranges to constant expressions.
3813   SmallVector<int64_t, 8> out;
3814   AffineValueMap rangesValueMap;
3815   AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3816                              &rangesValueMap);
3817   out.reserve(rangesValueMap.getNumResults());
3818   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3819     auto expr = rangesValueMap.getResult(i);
3820     auto cst = dyn_cast<AffineConstantExpr>(expr);
3821     if (!cst)
3822       return std::nullopt;
3823     out.push_back(cst.getValue());
3824   }
3825   return out;
3826 }
3827 
3828 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3829 
3830 OpBuilder AffineParallelOp::getBodyBuilder() {
3831   return OpBuilder(getBody(), std::prev(getBody()->end()));
3832 }
3833 
3834 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3835   assert(lbOperands.size() == map.getNumInputs() &&
3836          "operands to map must match number of inputs");
3837 
3838   auto ubOperands = getUpperBoundsOperands();
3839 
3840   SmallVector<Value, 4> newOperands(lbOperands);
3841   newOperands.append(ubOperands.begin(), ubOperands.end());
3842   (*this)->setOperands(newOperands);
3843 
3844   setLowerBoundsMapAttr(AffineMapAttr::get(map));
3845 }
3846 
3847 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3848   assert(ubOperands.size() == map.getNumInputs() &&
3849          "operands to map must match number of inputs");
3850 
3851   SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3852   newOperands.append(ubOperands.begin(), ubOperands.end());
3853   (*this)->setOperands(newOperands);
3854 
3855   setUpperBoundsMapAttr(AffineMapAttr::get(map));
3856 }
3857 
3858 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3859   setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3860 }
3861 
3862 // check whether resultType match op or not in affine.parallel
3863 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3864                                            arith::AtomicRMWKind op) {
3865   switch (op) {
3866   case arith::AtomicRMWKind::addf:
3867     return isa<FloatType>(resultType);
3868   case arith::AtomicRMWKind::addi:
3869     return isa<IntegerType>(resultType);
3870   case arith::AtomicRMWKind::assign:
3871     return true;
3872   case arith::AtomicRMWKind::mulf:
3873     return isa<FloatType>(resultType);
3874   case arith::AtomicRMWKind::muli:
3875     return isa<IntegerType>(resultType);
3876   case arith::AtomicRMWKind::maximumf:
3877     return isa<FloatType>(resultType);
3878   case arith::AtomicRMWKind::minimumf:
3879     return isa<FloatType>(resultType);
3880   case arith::AtomicRMWKind::maxs: {
3881     auto intType = llvm::dyn_cast<IntegerType>(resultType);
3882     return intType && intType.isSigned();
3883   }
3884   case arith::AtomicRMWKind::mins: {
3885     auto intType = llvm::dyn_cast<IntegerType>(resultType);
3886     return intType && intType.isSigned();
3887   }
3888   case arith::AtomicRMWKind::maxu: {
3889     auto intType = llvm::dyn_cast<IntegerType>(resultType);
3890     return intType && intType.isUnsigned();
3891   }
3892   case arith::AtomicRMWKind::minu: {
3893     auto intType = llvm::dyn_cast<IntegerType>(resultType);
3894     return intType && intType.isUnsigned();
3895   }
3896   case arith::AtomicRMWKind::ori:
3897     return isa<IntegerType>(resultType);
3898   case arith::AtomicRMWKind::andi:
3899     return isa<IntegerType>(resultType);
3900   default:
3901     return false;
3902   }
3903 }
3904 
3905 LogicalResult AffineParallelOp::verify() {
3906   auto numDims = getNumDims();
3907   if (getLowerBoundsGroups().getNumElements() != numDims ||
3908       getUpperBoundsGroups().getNumElements() != numDims ||
3909       getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3910     return emitOpError() << "the number of region arguments ("
3911                          << getBody()->getNumArguments()
3912                          << ") and the number of map groups for lower ("
3913                          << getLowerBoundsGroups().getNumElements()
3914                          << ") and upper bound ("
3915                          << getUpperBoundsGroups().getNumElements()
3916                          << "), and the number of steps (" << getSteps().size()
3917                          << ") must all match";
3918   }
3919 
3920   unsigned expectedNumLBResults = 0;
3921   for (APInt v : getLowerBoundsGroups())
3922     expectedNumLBResults += v.getZExtValue();
3923   if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3924     return emitOpError() << "expected lower bounds map to have "
3925                          << expectedNumLBResults << " results";
3926   unsigned expectedNumUBResults = 0;
3927   for (APInt v : getUpperBoundsGroups())
3928     expectedNumUBResults += v.getZExtValue();
3929   if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3930     return emitOpError() << "expected upper bounds map to have "
3931                          << expectedNumUBResults << " results";
3932 
3933   if (getReductions().size() != getNumResults())
3934     return emitOpError("a reduction must be specified for each output");
3935 
3936   // Verify reduction ops are all valid and each result type matches reduction
3937   // ops
3938   for (auto it : llvm::enumerate((getReductions()))) {
3939     Attribute attr = it.value();
3940     auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3941     if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3942       return emitOpError("invalid reduction attribute");
3943     auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3944     if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3945       return emitOpError("result type cannot match reduction attribute");
3946   }
3947 
3948   // Verify that the bound operands are valid dimension/symbols.
3949   /// Lower bounds.
3950   if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3951                                            getLowerBoundsMap().getNumDims())))
3952     return failure();
3953   /// Upper bounds.
3954   if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3955                                            getUpperBoundsMap().getNumDims())))
3956     return failure();
3957   return success();
3958 }
3959 
3960 LogicalResult AffineValueMap::canonicalize() {
3961   SmallVector<Value, 4> newOperands{operands};
3962   auto newMap = getAffineMap();
3963   composeAffineMapAndOperands(&newMap, &newOperands);
3964   if (newMap == getAffineMap() && newOperands == operands)
3965     return failure();
3966   reset(newMap, newOperands);
3967   return success();
3968 }
3969 
3970 /// Canonicalize the bounds of the given loop.
3971 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3972   AffineValueMap lb = op.getLowerBoundsValueMap();
3973   bool lbCanonicalized = succeeded(lb.canonicalize());
3974 
3975   AffineValueMap ub = op.getUpperBoundsValueMap();
3976   bool ubCanonicalized = succeeded(ub.canonicalize());
3977 
3978   // Any canonicalization change always leads to updated map(s).
3979   if (!lbCanonicalized && !ubCanonicalized)
3980     return failure();
3981 
3982   if (lbCanonicalized)
3983     op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3984   if (ubCanonicalized)
3985     op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3986 
3987   return success();
3988 }
3989 
3990 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3991                                      SmallVectorImpl<OpFoldResult> &results) {
3992   return canonicalizeLoopBounds(*this);
3993 }
3994 
3995 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3996 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3997 /// identifies which of the those expressions form max/min groups. `operands`
3998 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3999 /// or "max".
4000 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4001                              DenseIntElementsAttr group, ValueRange operands,
4002                              StringRef keyword) {
4003   AffineMap map = mapAttr.getValue();
4004   unsigned numDims = map.getNumDims();
4005   ValueRange dimOperands = operands.take_front(numDims);
4006   ValueRange symOperands = operands.drop_front(numDims);
4007   unsigned start = 0;
4008   for (llvm::APInt groupSize : group) {
4009     if (start != 0)
4010       p << ", ";
4011 
4012     unsigned size = groupSize.getZExtValue();
4013     if (size == 1) {
4014       p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4015       ++start;
4016     } else {
4017       p << keyword << '(';
4018       AffineMap submap = map.getSliceMap(start, size);
4019       p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4020       p << ')';
4021       start += size;
4022     }
4023   }
4024 }
4025 
4026 void AffineParallelOp::print(OpAsmPrinter &p) {
4027   p << " (" << getBody()->getArguments() << ") = (";
4028   printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4029                    getLowerBoundsOperands(), "max");
4030   p << ") to (";
4031   printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4032                    getUpperBoundsOperands(), "min");
4033   p << ')';
4034   SmallVector<int64_t, 8> steps = getSteps();
4035   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4036   if (!elideSteps) {
4037     p << " step (";
4038     llvm::interleaveComma(steps, p);
4039     p << ')';
4040   }
4041   if (getNumResults()) {
4042     p << " reduce (";
4043     llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4044       arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4045           llvm::cast<IntegerAttr>(attr).getInt());
4046       p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4047     });
4048     p << ") -> (" << getResultTypes() << ")";
4049   }
4050 
4051   p << ' ';
4052   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4053                 /*printBlockTerminators=*/getNumResults());
4054   p.printOptionalAttrDict(
4055       (*this)->getAttrs(),
4056       /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4057                        AffineParallelOp::getLowerBoundsMapAttrStrName(),
4058                        AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4059                        AffineParallelOp::getUpperBoundsMapAttrStrName(),
4060                        AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4061                        AffineParallelOp::getStepsAttrStrName()});
4062 }
4063 
4064 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4065 /// unique operands. Also populates `replacements with affine expressions of
4066 /// `kind` that can be used to update affine maps previously accepting a
4067 /// `operands` to accept `uniqueOperands` instead.
4068 static ParseResult deduplicateAndResolveOperands(
4069     OpAsmParser &parser,
4070     ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4071     SmallVectorImpl<Value> &uniqueOperands,
4072     SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4073   assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4074          "expected operands to be dim or symbol expression");
4075 
4076   Type indexType = parser.getBuilder().getIndexType();
4077   for (const auto &list : operands) {
4078     SmallVector<Value> valueOperands;
4079     if (parser.resolveOperands(list, indexType, valueOperands))
4080       return failure();
4081     for (Value operand : valueOperands) {
4082       unsigned pos = std::distance(uniqueOperands.begin(),
4083                                    llvm::find(uniqueOperands, operand));
4084       if (pos == uniqueOperands.size())
4085         uniqueOperands.push_back(operand);
4086       replacements.push_back(
4087           kind == AffineExprKind::DimId
4088               ? getAffineDimExpr(pos, parser.getContext())
4089               : getAffineSymbolExpr(pos, parser.getContext()));
4090     }
4091   }
4092   return success();
4093 }
4094 
4095 namespace {
4096 enum class MinMaxKind { Min, Max };
4097 } // namespace
4098 
4099 /// Parses an affine map that can contain a min/max for groups of its results,
4100 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4101 /// `result` attributes with the map (flat list of expressions) and the grouping
4102 /// (list of integers that specify how many expressions to put into each
4103 /// min/max) attributes. Deduplicates repeated operands.
4104 ///
4105 /// parallel-bound       ::= `(` parallel-group-list `)`
4106 /// parallel-group-list  ::= parallel-group (`,` parallel-group-list)?
4107 /// parallel-group       ::= simple-group | min-max-group
4108 /// simple-group         ::= expr-of-ssa-ids
4109 /// min-max-group        ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4110 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4111 ///
4112 /// Examples:
4113 ///   (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4114 ///   (%0, max(%1 - 2 * %2))
4115 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4116                                             OperationState &result,
4117                                             MinMaxKind kind) {
4118   // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4119   // see: https://reviews.llvm.org/D134227#3821753
4120   const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4121 
4122   StringRef mapName = kind == MinMaxKind::Min
4123                           ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4124                           : AffineParallelOp::getLowerBoundsMapAttrStrName();
4125   StringRef groupsName =
4126       kind == MinMaxKind::Min
4127           ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4128           : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4129 
4130   if (failed(parser.parseLParen()))
4131     return failure();
4132 
4133   if (succeeded(parser.parseOptionalRParen())) {
4134     result.addAttribute(
4135         mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4136     result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4137     return success();
4138   }
4139 
4140   SmallVector<AffineExpr> flatExprs;
4141   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4142   SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4143   SmallVector<int32_t> numMapsPerGroup;
4144   SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4145   auto parseOperands = [&]() {
4146     if (succeeded(parser.parseOptionalKeyword(
4147             kind == MinMaxKind::Min ? "min" : "max"))) {
4148       mapOperands.clear();
4149       AffineMapAttr map;
4150       if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4151                                                result.attributes,
4152                                                OpAsmParser::Delimiter::Paren)))
4153         return failure();
4154       result.attributes.erase(tmpAttrStrName);
4155       llvm::append_range(flatExprs, map.getValue().getResults());
4156       auto operandsRef = llvm::ArrayRef(mapOperands);
4157       auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4158       SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4159       auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4160       SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4161       flatDimOperands.append(map.getValue().getNumResults(), dims);
4162       flatSymOperands.append(map.getValue().getNumResults(), syms);
4163       numMapsPerGroup.push_back(map.getValue().getNumResults());
4164     } else {
4165       if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4166                                                 flatSymOperands.emplace_back(),
4167                                                 flatExprs.emplace_back())))
4168         return failure();
4169       numMapsPerGroup.push_back(1);
4170     }
4171     return success();
4172   };
4173   if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4174     return failure();
4175 
4176   unsigned totalNumDims = 0;
4177   unsigned totalNumSyms = 0;
4178   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4179     unsigned numDims = flatDimOperands[i].size();
4180     unsigned numSyms = flatSymOperands[i].size();
4181     flatExprs[i] = flatExprs[i]
4182                        .shiftDims(numDims, totalNumDims)
4183                        .shiftSymbols(numSyms, totalNumSyms);
4184     totalNumDims += numDims;
4185     totalNumSyms += numSyms;
4186   }
4187 
4188   // Deduplicate map operands.
4189   SmallVector<Value> dimOperands, symOperands;
4190   SmallVector<AffineExpr> dimRplacements, symRepacements;
4191   if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4192                                     dimRplacements, AffineExprKind::DimId) ||
4193       deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4194                                     symRepacements, AffineExprKind::SymbolId))
4195     return failure();
4196 
4197   result.operands.append(dimOperands.begin(), dimOperands.end());
4198   result.operands.append(symOperands.begin(), symOperands.end());
4199 
4200   Builder &builder = parser.getBuilder();
4201   auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4202                                 parser.getContext());
4203   flatMap = flatMap.replaceDimsAndSymbols(
4204       dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4205 
4206   result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4207   result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4208   return success();
4209 }
4210 
4211 //
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
//               `to` parallel-bound steps? reduction? region attr-dict?
// steps     ::= `step` `(` integer-literals `)`
// reduction ::= `reduce` `(` string-literal-list `)` `->` `(` type-list `)`
4215 //
4216 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4217                                     OperationState &result) {
4218   auto &builder = parser.getBuilder();
4219   auto indexType = builder.getIndexType();
4220   SmallVector<OpAsmParser::Argument, 4> ivs;
4221   if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
4222       parser.parseEqual() ||
4223       parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4224       parser.parseKeyword("to") ||
4225       parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4226     return failure();
4227 
4228   AffineMapAttr stepsMapAttr;
4229   NamedAttrList stepsAttrs;
4230   SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4231   if (failed(parser.parseOptionalKeyword("step"))) {
4232     SmallVector<int64_t, 4> steps(ivs.size(), 1);
4233     result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4234                         builder.getI64ArrayAttr(steps));
4235   } else {
4236     if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4237                                       AffineParallelOp::getStepsAttrStrName(),
4238                                       stepsAttrs,
4239                                       OpAsmParser::Delimiter::Paren))
4240       return failure();
4241 
4242     // Convert steps from an AffineMap into an I64ArrayAttr.
4243     SmallVector<int64_t, 4> steps;
4244     auto stepsMap = stepsMapAttr.getValue();
4245     for (const auto &result : stepsMap.getResults()) {
4246       auto constExpr = dyn_cast<AffineConstantExpr>(result);
4247       if (!constExpr)
4248         return parser.emitError(parser.getNameLoc(),
4249                                 "steps must be constant integers");
4250       steps.push_back(constExpr.getValue());
4251     }
4252     result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4253                         builder.getI64ArrayAttr(steps));
4254   }
4255 
4256   // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4257   // quoted strings are a member of the enum AtomicRMWKind.
4258   SmallVector<Attribute, 4> reductions;
4259   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4260     if (parser.parseLParen())
4261       return failure();
4262     auto parseAttributes = [&]() -> ParseResult {
4263       // Parse a single quoted string via the attribute parsing, and then
4264       // verify it is a member of the enum and convert to it's integer
4265       // representation.
4266       StringAttr attrVal;
4267       NamedAttrList attrStorage;
4268       auto loc = parser.getCurrentLocation();
4269       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4270                                 attrStorage))
4271         return failure();
4272       std::optional<arith::AtomicRMWKind> reduction =
4273           arith::symbolizeAtomicRMWKind(attrVal.getValue());
4274       if (!reduction)
4275         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4276       reductions.push_back(
4277           builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4278       // While we keep getting commas, keep parsing.
4279       return success();
4280     };
4281     if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4282       return failure();
4283   }
4284   result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4285                       builder.getArrayAttr(reductions));
4286 
4287   // Parse return types of reductions (if any)
4288   if (parser.parseOptionalArrowTypeList(result.types))
4289     return failure();
4290 
4291   // Now parse the body.
4292   Region *body = result.addRegion();
4293   for (auto &iv : ivs)
4294     iv.type = indexType;
4295   if (parser.parseRegion(*body, ivs) ||
4296       parser.parseOptionalAttrDict(result.attributes))
4297     return failure();
4298 
4299   // Add a terminator if none was parsed.
4300   AffineParallelOp::ensureTerminator(*body, builder, result.location);
4301   return success();
4302 }
4303 
4304 //===----------------------------------------------------------------------===//
4305 // AffineYieldOp
4306 //===----------------------------------------------------------------------===//
4307 
4308 LogicalResult AffineYieldOp::verify() {
4309   auto *parentOp = (*this)->getParentOp();
4310   auto results = parentOp->getResults();
4311   auto operands = getOperands();
4312 
4313   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4314     return emitOpError() << "only terminates affine.if/for/parallel regions";
4315   if (parentOp->getNumResults() != getNumOperands())
4316     return emitOpError() << "parent of yield must have same number of "
4317                             "results as the yield operands";
4318   for (auto it : llvm::zip(results, operands)) {
4319     if (std::get<0>(it).getType() != std::get<1>(it).getType())
4320       return emitOpError() << "types mismatch between yield op and its parent";
4321   }
4322 
4323   return success();
4324 }
4325 
4326 //===----------------------------------------------------------------------===//
4327 // AffineVectorLoadOp
4328 //===----------------------------------------------------------------------===//
4329 
4330 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4331                                VectorType resultType, AffineMap map,
4332                                ValueRange operands) {
4333   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4334   result.addOperands(operands);
4335   if (map)
4336     result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4337   result.types.push_back(resultType);
4338 }
4339 
4340 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4341                                VectorType resultType, Value memref,
4342                                AffineMap map, ValueRange mapOperands) {
4343   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4344   result.addOperands(memref);
4345   result.addOperands(mapOperands);
4346   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4347   result.types.push_back(resultType);
4348 }
4349 
4350 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4351                                VectorType resultType, Value memref,
4352                                ValueRange indices) {
4353   auto memrefType = llvm::cast<MemRefType>(memref.getType());
4354   int64_t rank = memrefType.getRank();
4355   // Create identity map for memrefs with at least one dimension or () -> ()
4356   // for zero-dimensional memrefs.
4357   auto map =
4358       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4359   build(builder, result, resultType, memref, map, indices);
4360 }
4361 
4362 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4363                                                      MLIRContext *context) {
4364   results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4365 }
4366 
4367 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4368                                       OperationState &result) {
4369   auto &builder = parser.getBuilder();
4370   auto indexTy = builder.getIndexType();
4371 
4372   MemRefType memrefType;
4373   VectorType resultType;
4374   OpAsmParser::UnresolvedOperand memrefInfo;
4375   AffineMapAttr mapAttr;
4376   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4377   return failure(
4378       parser.parseOperand(memrefInfo) ||
4379       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4380                                     AffineVectorLoadOp::getMapAttrStrName(),
4381                                     result.attributes) ||
4382       parser.parseOptionalAttrDict(result.attributes) ||
4383       parser.parseColonType(memrefType) || parser.parseComma() ||
4384       parser.parseType(resultType) ||
4385       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4386       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4387       parser.addTypeToList(resultType, result.types));
4388 }
4389 
4390 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4391   p << " " << getMemRef() << '[';
4392   if (AffineMapAttr mapAttr =
4393           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4394     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4395   p << ']';
4396   p.printOptionalAttrDict((*this)->getAttrs(),
4397                           /*elidedAttrs=*/{getMapAttrStrName()});
4398   p << " : " << getMemRefType() << ", " << getType();
4399 }
4400 
4401 /// Verify common invariants of affine.vector_load and affine.vector_store.
4402 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4403                                           VectorType vectorType) {
4404   // Check that memref and vector element types match.
4405   if (memrefType.getElementType() != vectorType.getElementType())
4406     return op->emitOpError(
4407         "requires memref and vector types of the same elemental type");
4408   return success();
4409 }
4410 
4411 LogicalResult AffineVectorLoadOp::verify() {
4412   MemRefType memrefType = getMemRefType();
4413   if (failed(verifyMemoryOpIndexing(
4414           getOperation(),
4415           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4416           getMapOperands(), memrefType,
4417           /*numIndexOperands=*/getNumOperands() - 1)))
4418     return failure();
4419 
4420   if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4421     return failure();
4422 
4423   return success();
4424 }
4425 
4426 //===----------------------------------------------------------------------===//
4427 // AffineVectorStoreOp
4428 //===----------------------------------------------------------------------===//
4429 
4430 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4431                                 Value valueToStore, Value memref, AffineMap map,
4432                                 ValueRange mapOperands) {
4433   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4434   result.addOperands(valueToStore);
4435   result.addOperands(memref);
4436   result.addOperands(mapOperands);
4437   result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4438 }
4439 
4440 // Use identity map.
4441 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4442                                 Value valueToStore, Value memref,
4443                                 ValueRange indices) {
4444   auto memrefType = llvm::cast<MemRefType>(memref.getType());
4445   int64_t rank = memrefType.getRank();
4446   // Create identity map for memrefs with at least one dimension or () -> ()
4447   // for zero-dimensional memrefs.
4448   auto map =
4449       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4450   build(builder, result, valueToStore, memref, map, indices);
4451 }
4452 void AffineVectorStoreOp::getCanonicalizationPatterns(
4453     RewritePatternSet &results, MLIRContext *context) {
4454   results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4455 }
4456 
4457 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4458                                        OperationState &result) {
4459   auto indexTy = parser.getBuilder().getIndexType();
4460 
4461   MemRefType memrefType;
4462   VectorType resultType;
4463   OpAsmParser::UnresolvedOperand storeValueInfo;
4464   OpAsmParser::UnresolvedOperand memrefInfo;
4465   AffineMapAttr mapAttr;
4466   SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4467   return failure(
4468       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4469       parser.parseOperand(memrefInfo) ||
4470       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4471                                     AffineVectorStoreOp::getMapAttrStrName(),
4472                                     result.attributes) ||
4473       parser.parseOptionalAttrDict(result.attributes) ||
4474       parser.parseColonType(memrefType) || parser.parseComma() ||
4475       parser.parseType(resultType) ||
4476       parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4477       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4478       parser.resolveOperands(mapOperands, indexTy, result.operands));
4479 }
4480 
4481 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4482   p << " " << getValueToStore();
4483   p << ", " << getMemRef() << '[';
4484   if (AffineMapAttr mapAttr =
4485           (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4486     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4487   p << ']';
4488   p.printOptionalAttrDict((*this)->getAttrs(),
4489                           /*elidedAttrs=*/{getMapAttrStrName()});
4490   p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4491 }
4492 
4493 LogicalResult AffineVectorStoreOp::verify() {
4494   MemRefType memrefType = getMemRefType();
4495   if (failed(verifyMemoryOpIndexing(
4496           *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4497           getMapOperands(), memrefType,
4498           /*numIndexOperands=*/getNumOperands() - 2)))
4499     return failure();
4500 
4501   if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4502     return failure();
4503 
4504   return success();
4505 }
4506 
4507 //===----------------------------------------------------------------------===//
4508 // DelinearizeIndexOp
4509 //===----------------------------------------------------------------------===//
4510 
4511 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4512                                      OperationState &odsState,
4513                                      Value linearIndex, ValueRange dynamicBasis,
4514                                      ArrayRef<int64_t> staticBasis,
4515                                      bool hasOuterBound) {
4516   SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4517                                               : staticBasis.size() + 1,
4518                                 linearIndex.getType());
4519   build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4520         staticBasis);
4521 }
4522 
4523 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4524                                      OperationState &odsState,
4525                                      Value linearIndex, ValueRange basis,
4526                                      bool hasOuterBound) {
4527   if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4528     hasOuterBound = false;
4529     basis = basis.drop_front();
4530   }
4531   SmallVector<Value> dynamicBasis;
4532   SmallVector<int64_t> staticBasis;
4533   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4534                              staticBasis);
4535   build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4536         hasOuterBound);
4537 }
4538 
4539 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4540                                      OperationState &odsState,
4541                                      Value linearIndex,
4542                                      ArrayRef<OpFoldResult> basis,
4543                                      bool hasOuterBound) {
4544   if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4545     hasOuterBound = false;
4546     basis = basis.drop_front();
4547   }
4548   SmallVector<Value> dynamicBasis;
4549   SmallVector<int64_t> staticBasis;
4550   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4551   build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4552         hasOuterBound);
4553 }
4554 
4555 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4556                                      OperationState &odsState,
4557                                      Value linearIndex, ArrayRef<int64_t> basis,
4558                                      bool hasOuterBound) {
4559   build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4560 }
4561 
4562 LogicalResult AffineDelinearizeIndexOp::verify() {
4563   ArrayRef<int64_t> staticBasis = getStaticBasis();
4564   if (getNumResults() != staticBasis.size() &&
4565       getNumResults() != staticBasis.size() + 1)
4566     return emitOpError("should return an index for each basis element and up "
4567                        "to one extra index");
4568 
4569   auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4570   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4571     return emitOpError(
4572         "mismatch between dynamic and static basis (kDynamic marker but no "
4573         "corresponding dynamic basis entry) -- this can only happen due to an "
4574         "incorrect fold/rewrite");
4575 
4576   if (!llvm::all_of(staticBasis, [](int64_t v) {
4577         return v > 0 || ShapedType::isDynamic(v);
4578       }))
4579     return emitOpError("no basis element may be statically non-positive");
4580 
4581   return success();
4582 }
4583 
4584 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4585 /// constant SSA values with the constant integer value and return the new
4586 /// static basis. In case no such candidate for replacement exists, this utility
4587 /// returns std::nullopt.
4588 static std::optional<SmallVector<int64_t>>
4589 foldCstValueToCstAttrBasis(ArrayRef<OpFoldResult> mixedBasis,
4590                            MutableOperandRange mutableDynamicBasis,
4591                            ArrayRef<Attribute> dynamicBasis) {
4592   uint64_t dynamicBasisIndex = 0;
4593   for (OpFoldResult basis : dynamicBasis) {
4594     if (basis) {
4595       mutableDynamicBasis.erase(dynamicBasisIndex);
4596     } else {
4597       ++dynamicBasisIndex;
4598     }
4599   }
4600 
4601   // No constant SSA value exists.
4602   if (dynamicBasisIndex == dynamicBasis.size())
4603     return std::nullopt;
4604 
4605   SmallVector<int64_t> staticBasis;
4606   for (OpFoldResult basis : mixedBasis) {
4607     std::optional<int64_t> basisVal = getConstantIntValue(basis);
4608     if (!basisVal)
4609       staticBasis.push_back(ShapedType::kDynamic);
4610     else
4611       staticBasis.push_back(*basisVal);
4612   }
4613 
4614   return staticBasis;
4615 }
4616 
4617 LogicalResult
4618 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4619                                SmallVectorImpl<OpFoldResult> &result) {
4620   std::optional<SmallVector<int64_t>> maybeStaticBasis =
4621       foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4622                                  adaptor.getDynamicBasis());
4623   if (maybeStaticBasis) {
4624     setStaticBasis(*maybeStaticBasis);
4625     return success();
4626   }
4627   // If we won't be doing any division or modulo (no basis or the one basis
4628   // element is purely advisory), simply return the input value.
4629   if (getNumResults() == 1) {
4630     result.push_back(getLinearIndex());
4631     return success();
4632   }
4633 
4634   if (adaptor.getLinearIndex() == nullptr)
4635     return failure();
4636 
4637   if (!adaptor.getDynamicBasis().empty())
4638     return failure();
4639 
4640   int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4641   Type attrType = getLinearIndex().getType();
4642 
4643   ArrayRef<int64_t> staticBasis = getStaticBasis();
4644   if (hasOuterBound())
4645     staticBasis = staticBasis.drop_front();
4646   for (int64_t modulus : llvm::reverse(staticBasis)) {
4647     result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4648     highPart = llvm::divideFloorSigned(highPart, modulus);
4649   }
4650   result.push_back(IntegerAttr::get(attrType, highPart));
4651   std::reverse(result.begin(), result.end());
4652   return success();
4653 }
4654 
4655 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4656   OpBuilder builder(getContext());
4657   if (hasOuterBound()) {
4658     if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4659       return getMixedValues(getStaticBasis().drop_front(),
4660                             getDynamicBasis().drop_front(), builder);
4661 
4662     return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4663                           builder);
4664   }
4665 
4666   return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4667 }
4668 
4669 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4670   SmallVector<OpFoldResult> ret = getMixedBasis();
4671   if (!hasOuterBound())
4672     ret.insert(ret.begin(), OpFoldResult());
4673   return ret;
4674 }
4675 
4676 namespace {
4677 
4678 // Drops delinearization indices that correspond to unit-extent basis
4679 struct DropUnitExtentBasis
4680     : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4681   using OpRewritePattern::OpRewritePattern;
4682 
4683   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4684                                 PatternRewriter &rewriter) const override {
4685     SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4686     std::optional<Value> zero = std::nullopt;
4687     Location loc = delinearizeOp->getLoc();
4688     auto getZero = [&]() -> Value {
4689       if (!zero)
4690         zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4691       return zero.value();
4692     };
4693 
4694     // Replace all indices corresponding to unit-extent basis with 0.
4695     // Remaining basis can be used to get a new `affine.delinearize_index` op.
4696     SmallVector<OpFoldResult> newBasis;
4697     for (auto [index, basis] :
4698          llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4699       std::optional<int64_t> basisVal =
4700           basis ? getConstantIntValue(basis) : std::nullopt;
4701       if (basisVal && *basisVal == 1)
4702         replacements[index] = getZero();
4703       else
4704         newBasis.push_back(basis);
4705     }
4706 
4707     if (newBasis.size() == delinearizeOp.getNumResults())
4708       return rewriter.notifyMatchFailure(delinearizeOp,
4709                                          "no unit basis elements");
4710 
4711     if (!newBasis.empty()) {
4712       // Will drop the leading nullptr from `basis` if there was no outer bound.
4713       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4714           loc, delinearizeOp.getLinearIndex(), newBasis);
4715       int newIndex = 0;
4716       // Map back the new delinearized indices to the values they replace.
4717       for (auto &replacement : replacements) {
4718         if (replacement)
4719           continue;
4720         replacement = newDelinearizeOp->getResult(newIndex++);
4721       }
4722     }
4723 
4724     rewriter.replaceOp(delinearizeOp, replacements);
4725     return success();
4726   }
4727 };
4728 
4729 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4730 /// disjoint` and the two operations end with the same basis elements,
4731 /// cancel those parts of the operations out because they are inverses
4732 /// of each other.
4733 ///
4734 /// If the operations have the same basis, cancel them entirely.
4735 ///
4736 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4737 /// otherwise, there is no guarantee that the inputs to the linearization are
4738 /// in-bounds the way the outputs of the delinearization would be.
4739 struct CancelDelinearizeOfLinearizeDisjointExactTail
4740     : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4741   using OpRewritePattern::OpRewritePattern;
4742 
4743   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4744                                 PatternRewriter &rewriter) const override {
4745     auto linearizeOp = delinearizeOp.getLinearIndex()
4746                            .getDefiningOp<affine::AffineLinearizeIndexOp>();
4747     if (!linearizeOp)
4748       return rewriter.notifyMatchFailure(delinearizeOp,
4749                                          "index doesn't come from linearize");
4750 
4751     if (!linearizeOp.getDisjoint())
4752       return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4753 
4754     ValueRange linearizeIns = linearizeOp.getMultiIndex();
4755     // Note: we use the full basis so we don't lose outer bounds later.
4756     SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4757     SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4758     size_t numMatches = 0;
4759     for (auto [linSize, delinSize] : llvm::zip(
4760              llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4761       if (linSize != delinSize)
4762         break;
4763       ++numMatches;
4764     }
4765 
4766     if (numMatches == 0)
4767       return rewriter.notifyMatchFailure(
4768           delinearizeOp, "final basis element doesn't match linearize");
4769 
4770     // The easy case: everything lines up and the basis match sup completely.
4771     if (numMatches == linearizeBasis.size() &&
4772         numMatches == delinearizeBasis.size() &&
4773         linearizeIns.size() == delinearizeOp.getNumResults()) {
4774       rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4775       return success();
4776     }
4777 
4778     Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4779         linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4780         ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4781         linearizeOp.getDisjoint());
4782     auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4783         delinearizeOp.getLoc(), newLinearize,
4784         ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4785         delinearizeOp.hasOuterBound());
4786     SmallVector<Value> mergedResults(newDelinearize.getResults());
4787     mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4788                          linearizeIns.take_back(numMatches).end());
4789     rewriter.replaceOp(delinearizeOp, mergedResults);
4790     return success();
4791   }
4792 };
4793 
4794 /// If the input to a delinearization is a disjoint linearization, and the
4795 /// last k > 1 components of the delinearization basis multiply to the
4796 /// last component of the linearization basis, break the linearization and
4797 /// delinearization into two parts, peeling off the last input to linearization.
4798 ///
4799 /// For example:
4800 ///    %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4801 ///    %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4802 /// becomes
4803 ///    %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4804 ///    %1:2 = affine.delinearize_index %0 by (2, 3) : index
4805 ///    %2:2 = affine.delinearize_index %x by (8, 4) : index
4806 /// where the original %1:4 is replaced by %1:2 ++ %2:2
4807 struct SplitDelinearizeSpanningLastLinearizeArg final
4808     : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4809   using OpRewritePattern::OpRewritePattern;
4810 
4811   LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4812                                 PatternRewriter &rewriter) const override {
4813     auto linearizeOp = delinearizeOp.getLinearIndex()
4814                            .getDefiningOp<affine::AffineLinearizeIndexOp>();
4815     if (!linearizeOp)
4816       return rewriter.notifyMatchFailure(delinearizeOp,
4817                                          "index doesn't come from linearize");
4818 
4819     if (!linearizeOp.getDisjoint())
4820       return rewriter.notifyMatchFailure(linearizeOp,
4821                                          "linearize isn't disjoint");
4822 
4823     int64_t target = linearizeOp.getStaticBasis().back();
4824     if (ShapedType::isDynamic(target))
4825       return rewriter.notifyMatchFailure(
4826           linearizeOp, "linearize ends with dynamic basis value");
4827 
4828     int64_t sizeToSplit = 1;
4829     size_t elemsToSplit = 0;
4830     ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4831     for (int64_t basisElem : llvm::reverse(basis)) {
4832       if (ShapedType::isDynamic(basisElem))
4833         return rewriter.notifyMatchFailure(
4834             delinearizeOp, "dynamic basis element while scanning for split");
4835       sizeToSplit *= basisElem;
4836       elemsToSplit += 1;
4837 
4838       if (sizeToSplit > target)
4839         return rewriter.notifyMatchFailure(delinearizeOp,
4840                                            "overshot last argument size");
4841       if (sizeToSplit == target)
4842         break;
4843     }
4844 
4845     if (sizeToSplit < target)
4846       return rewriter.notifyMatchFailure(
4847           delinearizeOp, "product of known basis elements doesn't exceed last "
4848                          "linearize argument");
4849 
4850     if (elemsToSplit < 2)
4851       return rewriter.notifyMatchFailure(
4852           delinearizeOp,
4853           "need at least two elements to form the basis product");
4854 
4855     Value linearizeWithoutBack =
4856         rewriter.create<affine::AffineLinearizeIndexOp>(
4857             linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4858             linearizeOp.getDynamicBasis(),
4859             linearizeOp.getStaticBasis().drop_back(),
4860             linearizeOp.getDisjoint());
4861     auto delinearizeWithoutSplitPart =
4862         rewriter.create<affine::AffineDelinearizeIndexOp>(
4863             delinearizeOp.getLoc(), linearizeWithoutBack,
4864             delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4865             delinearizeOp.hasOuterBound());
4866     auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4867         delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4868         basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4869     SmallVector<Value> results = llvm::to_vector(
4870         llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4871                             delinearizeBack.getResults()));
4872     rewriter.replaceOp(delinearizeOp, results);
4873 
4874     return success();
4875   }
4876 };
4877 } // namespace
4878 
4879 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4880     RewritePatternSet &patterns, MLIRContext *context) {
4881   patterns
4882       .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4883               DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4884           context);
4885 }
4886 
4887 //===----------------------------------------------------------------------===//
4888 // LinearizeIndexOp
4889 //===----------------------------------------------------------------------===//
4890 
4891 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4892                                    OperationState &odsState,
4893                                    ValueRange multiIndex, ValueRange basis,
4894                                    bool disjoint) {
4895   if (!basis.empty() && basis.front() == Value())
4896     basis = basis.drop_front();
4897   SmallVector<Value> dynamicBasis;
4898   SmallVector<int64_t> staticBasis;
4899   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4900                              staticBasis);
4901   build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4902 }
4903 
4904 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4905                                    OperationState &odsState,
4906                                    ValueRange multiIndex,
4907                                    ArrayRef<OpFoldResult> basis,
4908                                    bool disjoint) {
4909   if (!basis.empty() && basis.front() == OpFoldResult())
4910     basis = basis.drop_front();
4911   SmallVector<Value> dynamicBasis;
4912   SmallVector<int64_t> staticBasis;
4913   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4914   build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4915 }
4916 
4917 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4918                                    OperationState &odsState,
4919                                    ValueRange multiIndex,
4920                                    ArrayRef<int64_t> basis, bool disjoint) {
4921   build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4922 }
4923 
4924 LogicalResult AffineLinearizeIndexOp::verify() {
4925   size_t numIndexes = getMultiIndex().size();
4926   size_t numBasisElems = getStaticBasis().size();
4927   if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
4928     return emitOpError("should be passed a basis element for each index except "
4929                        "possibly the first");
4930 
4931   auto dynamicMarkersCount =
4932       llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4933   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4934     return emitOpError(
4935         "mismatch between dynamic and static basis (kDynamic marker but no "
4936         "corresponding dynamic basis entry) -- this can only happen due to an "
4937         "incorrect fold/rewrite");
4938 
4939   return success();
4940 }
4941 
4942 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
4943   std::optional<SmallVector<int64_t>> maybeStaticBasis =
4944       foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4945                                  adaptor.getDynamicBasis());
4946   if (maybeStaticBasis) {
4947     setStaticBasis(*maybeStaticBasis);
4948     return getResult();
4949   }
4950   // No indices linearizes to zero.
4951   if (getMultiIndex().empty())
4952     return IntegerAttr::get(getResult().getType(), 0);
4953 
4954   // One single index linearizes to itself.
4955   if (getMultiIndex().size() == 1)
4956     return getMultiIndex().front();
4957 
4958   if (llvm::any_of(adaptor.getMultiIndex(),
4959                    [](Attribute a) { return a == nullptr; }))
4960     return nullptr;
4961 
4962   if (!adaptor.getDynamicBasis().empty())
4963     return nullptr;
4964 
4965   int64_t result = 0;
4966   int64_t stride = 1;
4967   for (auto [length, indexAttr] :
4968        llvm::zip_first(llvm::reverse(getStaticBasis()),
4969                        llvm::reverse(adaptor.getMultiIndex()))) {
4970     result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
4971     stride = stride * length;
4972   }
4973   // Handle the index element with no basis element.
4974   if (!hasOuterBound())
4975     result =
4976         result +
4977         cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
4978 
4979   return IntegerAttr::get(getResult().getType(), result);
4980 }
4981 
4982 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
4983   OpBuilder builder(getContext());
4984   if (hasOuterBound()) {
4985     if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4986       return getMixedValues(getStaticBasis().drop_front(),
4987                             getDynamicBasis().drop_front(), builder);
4988 
4989     return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4990                           builder);
4991   }
4992 
4993   return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4994 }
4995 
4996 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
4997   SmallVector<OpFoldResult> ret = getMixedBasis();
4998   if (!hasOuterBound())
4999     ret.insert(ret.begin(), OpFoldResult());
5000   return ret;
5001 }
5002 
5003 namespace {
5004 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5005 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5006 /// %...d)`.
5007 
5008 /// Note that `disjoint` is required here, because, without it, we could have
5009 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5010 /// is a valid operation where the `%c64` cannot be trivially dropped.
5011 ///
5012 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5013 /// the operation isn't asserted to be `disjoint`.
5014 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5015     : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5016   using OpRewritePattern::OpRewritePattern;
5017 
5018   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5019                                 PatternRewriter &rewriter) const override {
5020     ValueRange multiIndex = op.getMultiIndex();
5021     size_t numIndices = multiIndex.size();
5022     SmallVector<Value> newIndices;
5023     newIndices.reserve(numIndices);
5024     SmallVector<OpFoldResult> newBasis;
5025     newBasis.reserve(numIndices);
5026 
5027     if (!op.hasOuterBound()) {
5028       newIndices.push_back(multiIndex.front());
5029       multiIndex = multiIndex.drop_front();
5030     }
5031 
5032     SmallVector<OpFoldResult> basis = op.getMixedBasis();
5033     for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5034       std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5035       if (!basisEntry || *basisEntry != 1) {
5036         newIndices.push_back(index);
5037         newBasis.push_back(basisElem);
5038         continue;
5039       }
5040 
5041       std::optional<int64_t> indexValue = getConstantIntValue(index);
5042       if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5043         newIndices.push_back(index);
5044         newBasis.push_back(basisElem);
5045         continue;
5046       }
5047     }
5048     if (newIndices.size() == numIndices)
5049       return rewriter.notifyMatchFailure(op,
5050                                          "no unit basis entries to replace");
5051 
5052     if (newIndices.size() == 0) {
5053       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5054       return success();
5055     }
5056     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5057         op, newIndices, newBasis, op.getDisjoint());
5058     return success();
5059   }
5060 };
5061 
5062 /// Return the product of `terms`, creating an `affine.apply` if any of them are
5063 /// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
5064 static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5065                                    ArrayRef<OpFoldResult> terms) {
5066   int64_t nDynamic = 0;
5067   SmallVector<Value> dynamicPart;
5068   AffineExpr result = builder.getAffineConstantExpr(1);
5069   for (OpFoldResult term : terms) {
5070     if (!term)
5071       return term;
5072     std::optional<int64_t> maybeConst = getConstantIntValue(term);
5073     if (maybeConst) {
5074       result = result * builder.getAffineConstantExpr(*maybeConst);
5075     } else {
5076       dynamicPart.push_back(cast<Value>(term));
5077       result = result * builder.getAffineSymbolExpr(nDynamic++);
5078     }
5079   }
5080   if (auto constant = dyn_cast<AffineConstantExpr>(result))
5081     return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5082   return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5083 }
5084 
/// If consecutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
/// That is, if we have
/// ```
/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
///   by (...e, B1, B2, ..., BK, ...f)
/// ```
///
/// We can rewrite this to
/// ```
/// B = B1 * B2 ... BK
/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
/// ```
/// Other uses of %s are then rewired: results outside the merged run are
/// taken from %sMerged, and results inside it are taken from a residual
/// `affine.delinearize_index %sMerged#I into (B1, ..., BK)`.
///
/// As a special case, if all results of the delinearize are merged in this way
/// we can replace those usages with %x, thus cancelling the delinearization
/// entirely, as in
/// ```
/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
/// ```
/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`.
/// In that case the original delinearize is not replaced; any other users
/// keep using it.
///
/// B is computed from the linearize_index's own basis slice, not from the
/// delinearize's, so that matches whose first bound is only known on the
/// linearize side (see the `disjoint` case below) are handled.
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // Struct representing a case where the cancellation pattern
  // applies. A `Match` means that `length` inputs to the linearize operation
  // starting at `linStart` can be cancelled with `length` outputs of
  // `delinearize`, starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    // Padded so that linBasis[i] is the bound of multiIndex[i]; a null entry
    // at position 0 means the linearize has no outer bound.
    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // We only want to replace one run from the same delinearize op per
    // pattern invocation lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    // Phase 1: scan the linearize's inputs for runs of consecutive results of
    // some delinearize_index whose bounds agree.
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      // Result 0 of the delinearize and argument 0 of the linearize can
      // leave their maximum value unspecified. However, even if this happens
      // we can still sometimes start the match process. Specifically, if
      // - The argument we're matching is result 0 and argument 0 (so the
      //   bounds don't matter). For example,
      //
      //     %s:2 = affine.delinearize_index %x into (8) : index, index
      //     %t = affine.linearize_index [%s#0, %s#1, ...] by (8, ...)
      //   allows cancellation
      // - The delinearization doesn't specify a bound, but the linearization
      //   is `disjoint`, which asserts that the bound on the linearization is
      //   correct.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Extend the run while the next linearize input is the next result of
      // the same delinearize and both sides use the same bound for it.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // A run of length 1 merges nothing. If there are multiple matches
      // against the same delinearize_index, only rewrite the first one we
      // find to prevent invalidations. The next ones will be taken care of by
      // subsequent pattern invocations.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record all the delinearize replacements so we can do them after creating
    // the new linearization operation, since the new operation might use
    // outputs of something we're replacing.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    // Phase 2: build the new linearize operands, copying unmatched gaps
    // verbatim and collapsing each matched run into a single index.
    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      // Update here so we don't forget this during early continues
      prevMatchEnd = m.linStart + m.length;

      // New ops for this match are created immediately before the matched
      // delinearize, whose results they will replace.
      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // We use the slice from the linearize's basis above because of the
      // "bounds inferred from `disjoint`" case above.
      // NOTE(review): in the `bothAtFront` / `knownByDisjoint` cases the first
      // element of `basisToMerge` comes only from the linearize side and may
      // be an SSA value defined after `m.delinearize`; an `affine.apply`
      // created here would then use it before its definition. Confirm this
      // cannot happen, or create the product elsewhere.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // Trivial case where we can just skip past the delinearize altogether
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        // Pad out set of replacements so we don't do anything with this one.
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      // Replace the run's bounds in the delinearize basis by their product.
      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // Since there may be other uses of the indices we just merged together,
      // create a residual affine.delinearize_index that delinearizes the
      // merged output into its component parts.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Swap all the uses of the unaffected delinearize outputs to the new
      // delinearization so that the old code can be removed if this
      // linearize_index is the only user of the merged results.
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    // Phase 3: now that the new linearize exists, replace each partially
    // merged delinearize (fully cancelled ones were recorded as empty).
    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
5287 
5288 /// Strip leading zero from affine.linearize_index.
5289 ///
5290 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5291 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5292 struct DropLinearizeLeadingZero final
5293     : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5294   using OpRewritePattern::OpRewritePattern;
5295 
5296   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5297                                 PatternRewriter &rewriter) const override {
5298     Value leadingIdx = op.getMultiIndex().front();
5299     if (!matchPattern(leadingIdx, m_Zero()))
5300       return failure();
5301 
5302     if (op.getMultiIndex().size() == 1) {
5303       rewriter.replaceOp(op, leadingIdx);
5304       return success();
5305     }
5306 
5307     SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5308     ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5309     if (op.hasOuterBound())
5310       newMixedBasis = newMixedBasis.drop_front();
5311 
5312     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5313         op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5314     return success();
5315   }
5316 };
5317 } // namespace
5318 
5319 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5320     RewritePatternSet &patterns, MLIRContext *context) {
5321   patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5322                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5323 }
5324 
5325 //===----------------------------------------------------------------------===//
5326 // TableGen'd op method definitions
5327 //===----------------------------------------------------------------------===//
5328 
5329 #define GET_OP_CLASSES
5330 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
5331