xref: /llvm-project/mlir/test/lib/Dialect/Test/TestOpDefs.cpp (revision 3c64f86314fbf9a3cd578419f16e621a4de57eaa)
1 //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestOps.h"
11 #include "mlir/Dialect/Tensor/IR/Tensor.h"
12 #include "mlir/IR/Verifier.h"
13 #include "mlir/Interfaces/FunctionImplementation.h"
14 #include "mlir/Interfaces/MemorySlotInterfaces.h"
15 
16 using namespace mlir;
17 using namespace test;
18 
19 //===----------------------------------------------------------------------===//
20 // TestBranchOp
21 //===----------------------------------------------------------------------===//
22 
23 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
24   assert(index == 0 && "invalid successor index");
25   return SuccessorOperands(getTargetOperandsMutable());
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestProducingBranchOp
30 //===----------------------------------------------------------------------===//
31 
32 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
33   assert(index <= 1 && "invalid successor index");
34   if (index == 1)
35     return SuccessorOperands(getFirstOperandsMutable());
36   return SuccessorOperands(getSecondOperandsMutable());
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // TestInternalBranchOp
41 //===----------------------------------------------------------------------===//
42 
43 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
44   assert(index <= 1 && "invalid successor index");
45   if (index == 0)
46     return SuccessorOperands(0, getSuccessOperandsMutable());
47   return SuccessorOperands(1, getErrorOperandsMutable());
48 }
49 
50 //===----------------------------------------------------------------------===//
51 // TestCallOp
52 //===----------------------------------------------------------------------===//
53 
54 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
55   // Check that the callee attribute was specified.
56   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
57   if (!fnAttr)
58     return emitOpError("requires a 'callee' symbol reference attribute");
59   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
60     return emitOpError() << "'" << fnAttr.getValue()
61                          << "' does not reference a valid function";
62   return success();
63 }
64 
65 //===----------------------------------------------------------------------===//
66 // FoldToCallOp
67 //===----------------------------------------------------------------------===//
68 
69 namespace {
70 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
71   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
72 
73   LogicalResult matchAndRewrite(FoldToCallOp op,
74                                 PatternRewriter &rewriter) const override {
75     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
76                                               op.getCalleeAttr(), ValueRange());
77     return success();
78   }
79 };
80 } // namespace
81 
82 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
83                                                MLIRContext *context) {
84   results.add<FoldToCallOpPattern>(context);
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // IsolatedRegionOp - test parsing passthrough operands
89 //===----------------------------------------------------------------------===//
90 
91 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
92                                     OperationState &result) {
93   // Parse the input operand.
94   OpAsmParser::Argument argInfo;
95   argInfo.type = parser.getBuilder().getIndexType();
96   if (parser.parseOperand(argInfo.ssaName) ||
97       parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
98     return failure();
99 
100   // Parse the body region, and reuse the operand info as the argument info.
101   Region *body = result.addRegion();
102   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
103 }
104 
105 void IsolatedRegionOp::print(OpAsmPrinter &p) {
106   p << ' ';
107   p.printOperand(getOperand());
108   p.shadowRegionArgs(getRegion(), getOperand());
109   p << ' ';
110   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // SSACFGRegionOp
115 //===----------------------------------------------------------------------===//
116 
117 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
118   return RegionKind::SSACFG;
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // GraphRegionOp
123 //===----------------------------------------------------------------------===//
124 
125 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
126   return RegionKind::Graph;
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // IsolatedGraphRegionOp
131 //===----------------------------------------------------------------------===//
132 
133 RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) {
134   return RegionKind::Graph;
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // AffineScopeOp
139 //===----------------------------------------------------------------------===//
140 
141 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
142   // Parse the body region, and reuse the operand info as the argument info.
143   Region *body = result.addRegion();
144   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
145 }
146 
147 void AffineScopeOp::print(OpAsmPrinter &p) {
148   p << " ";
149   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // TestRemoveOpWithInnerOps
154 //===----------------------------------------------------------------------===//
155 
156 namespace {
157 struct TestRemoveOpWithInnerOps
158     : public OpRewritePattern<TestOpWithRegionPattern> {
159   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
160 
161   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
162 
163   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
164                                 PatternRewriter &rewriter) const override {
165     rewriter.eraseOp(op);
166     return success();
167   }
168 };
169 } // namespace
170 
171 //===----------------------------------------------------------------------===//
172 // TestOpWithRegionPattern
173 //===----------------------------------------------------------------------===//
174 
175 void TestOpWithRegionPattern::getCanonicalizationPatterns(
176     RewritePatternSet &results, MLIRContext *context) {
177   results.add<TestRemoveOpWithInnerOps>(context);
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // TestOpWithRegionFold
182 //===----------------------------------------------------------------------===//
183 
184 OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
185   return getOperand();
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // TestOpConstant
190 //===----------------------------------------------------------------------===//
191 
192 OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
193 
194 //===----------------------------------------------------------------------===//
195 // TestOpWithVariadicResultsAndFolder
196 //===----------------------------------------------------------------------===//
197 
198 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
199     FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
200   for (Value input : this->getOperands()) {
201     results.push_back(input);
202   }
203   return success();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // TestOpInPlaceFold
208 //===----------------------------------------------------------------------===//
209 
210 OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
211   // Exercise the fact that an operation created with createOrFold should be
212   // allowed to access its parent block.
213   assert(getOperation()->getBlock() &&
214          "expected that operation is not unlinked");
215 
216   if (adaptor.getOp() && !getProperties().attr) {
217     // The folder adds "attr" if not present.
218     getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp());
219     return getResult();
220   }
221   return {};
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // OpWithInferTypeInterfaceOp
226 //===----------------------------------------------------------------------===//
227 
228 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
229     MLIRContext *, std::optional<Location> location, ValueRange operands,
230     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
231     SmallVectorImpl<Type> &inferredReturnTypes) {
232   if (operands[0].getType() != operands[1].getType()) {
233     return emitOptionalError(location, "operand type mismatch ",
234                              operands[0].getType(), " vs ",
235                              operands[1].getType());
236   }
237   inferredReturnTypes.assign({operands[0].getType()});
238   return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // OpWithShapedTypeInferTypeInterfaceOp
243 //===----------------------------------------------------------------------===//
244 
245 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
246     MLIRContext *context, std::optional<Location> location,
247     ValueShapeRange operands, DictionaryAttr attributes,
248     OpaqueProperties properties, RegionRange regions,
249     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
250   // Create return type consisting of the last element of the first operand.
251   auto operandType = operands.front().getType();
252   auto sval = dyn_cast<ShapedType>(operandType);
253   if (!sval)
254     return emitOptionalError(location, "only shaped type operands allowed");
255   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
256   auto type = IntegerType::get(context, 17);
257 
258   Attribute encoding;
259   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
260     encoding = rankedTy.getEncoding();
261   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
262   return success();
263 }
264 
265 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
266     OpBuilder &builder, ValueRange operands,
267     llvm::SmallVectorImpl<Value> &shapes) {
268   shapes = SmallVector<Value, 1>{
269       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
270   return success();
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // OpWithResultShapeInterfaceOp
275 //===----------------------------------------------------------------------===//
276 
277 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
278     OpBuilder &builder, ValueRange operands,
279     llvm::SmallVectorImpl<Value> &shapes) {
280   Location loc = getLoc();
281   shapes.reserve(operands.size());
282   for (Value operand : llvm::reverse(operands)) {
283     auto rank = cast<RankedTensorType>(operand.getType()).getRank();
284     auto currShape = llvm::to_vector<4>(
285         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
286           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
287         }));
288     shapes.push_back(builder.create<tensor::FromElementsOp>(
289         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
290         currShape));
291   }
292   return success();
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // OpWithResultShapePerDimInterfaceOp
297 //===----------------------------------------------------------------------===//
298 
299 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
300     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
301   Location loc = getLoc();
302   shapes.reserve(getNumOperands());
303   for (Value operand : llvm::reverse(getOperands())) {
304     auto tensorType = cast<RankedTensorType>(operand.getType());
305     auto currShape = llvm::to_vector<4>(llvm::map_range(
306         llvm::seq<int64_t>(0, tensorType.getRank()),
307         [&](int64_t dim) -> OpFoldResult {
308           return tensorType.isDynamicDim(dim)
309                      ? static_cast<OpFoldResult>(
310                            builder.createOrFold<tensor::DimOp>(loc, operand,
311                                                                dim))
312                      : static_cast<OpFoldResult>(
313                            builder.getIndexAttr(tensorType.getDimSize(dim)));
314         }));
315     shapes.emplace_back(std::move(currShape));
316   }
317   return success();
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // SideEffectOp
322 //===----------------------------------------------------------------------===//
323 
324 namespace {
325 /// A test resource for side effects.
326 struct TestResource : public SideEffects::Resource::Base<TestResource> {
327   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
328 
329   StringRef getName() final { return "<Test>"; }
330 };
331 } // namespace
332 
333 void SideEffectOp::getEffects(
334     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
335   // Check for an effects attribute on the op instance.
336   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
337   if (!effectsAttr)
338     return;
339 
340   for (Attribute element : effectsAttr) {
341     DictionaryAttr effectElement = cast<DictionaryAttr>(element);
342 
343     // Get the specific memory effect.
344     MemoryEffects::Effect *effect =
345         StringSwitch<MemoryEffects::Effect *>(
346             cast<StringAttr>(effectElement.get("effect")).getValue())
347             .Case("allocate", MemoryEffects::Allocate::get())
348             .Case("free", MemoryEffects::Free::get())
349             .Case("read", MemoryEffects::Read::get())
350             .Case("write", MemoryEffects::Write::get());
351 
352     // Check for a non-default resource to use.
353     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
354     if (effectElement.get("test_resource"))
355       resource = TestResource::get();
356 
357     // Check for a result to affect.
358     if (effectElement.get("on_result"))
359       effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
360     else if (Attribute ref = effectElement.get("on_reference"))
361       effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
362     else
363       effects.emplace_back(effect, resource);
364   }
365 }
366 
367 void SideEffectOp::getEffects(
368     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
369   testSideEffectOpGetEffect(getOperation(), effects);
370 }
371 
372 void SideEffectWithRegionOp::getEffects(
373     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
374   // Check for an effects attribute on the op instance.
375   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
376   if (!effectsAttr)
377     return;
378 
379   for (Attribute element : effectsAttr) {
380     DictionaryAttr effectElement = cast<DictionaryAttr>(element);
381 
382     // Get the specific memory effect.
383     MemoryEffects::Effect *effect =
384         StringSwitch<MemoryEffects::Effect *>(
385             cast<StringAttr>(effectElement.get("effect")).getValue())
386             .Case("allocate", MemoryEffects::Allocate::get())
387             .Case("free", MemoryEffects::Free::get())
388             .Case("read", MemoryEffects::Read::get())
389             .Case("write", MemoryEffects::Write::get());
390 
391     // Check for a non-default resource to use.
392     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
393     if (effectElement.get("test_resource"))
394       resource = TestResource::get();
395 
396     // Check for a result to affect.
397     if (effectElement.get("on_result"))
398       effects.emplace_back(effect, getOperation()->getOpResults()[0], resource);
399     else if (effectElement.get("on_operand"))
400       effects.emplace_back(effect, &getOperation()->getOpOperands()[0],
401                            resource);
402     else if (effectElement.get("on_argument"))
403       effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0),
404                            resource);
405     else if (Attribute ref = effectElement.get("on_reference"))
406       effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource);
407     else
408       effects.emplace_back(effect, resource);
409   }
410 }
411 
412 void SideEffectWithRegionOp::getEffects(
413     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
414   testSideEffectOpGetEffect(getOperation(), effects);
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // StringAttrPrettyNameOp
419 //===----------------------------------------------------------------------===//
420 
421 // This op has fancy handling of its SSA result name.
422 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
423                                           OperationState &result) {
424   // Add the result types.
425   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
426     result.addTypes(parser.getBuilder().getIntegerType(32));
427 
428   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
429     return failure();
430 
431   // If the attribute dictionary contains no 'names' attribute, infer it from
432   // the SSA name (if specified).
433   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
434     return attr.getName() == "names";
435   });
436 
437   // If there was no name specified, check to see if there was a useful name
438   // specified in the asm file.
439   if (hadNames || parser.getNumResults() == 0)
440     return success();
441 
442   SmallVector<StringRef, 4> names;
443   auto *context = result.getContext();
444 
445   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
446     auto resultName = parser.getResultName(i);
447     StringRef nameStr;
448     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
449       nameStr = resultName.first;
450 
451     names.push_back(nameStr);
452   }
453 
454   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
455   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
456   return success();
457 }
458 
459 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
460   // Note that we only need to print the "name" attribute if the asmprinter
461   // result name disagrees with it.  This can happen in strange cases, e.g.
462   // when there are conflicts.
463   bool namesDisagree = getNames().size() != getNumResults();
464 
465   SmallString<32> resultNameStr;
466   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
467     resultNameStr.clear();
468     llvm::raw_svector_ostream tmpStream(resultNameStr);
469     p.printOperand(getResult(i), tmpStream);
470 
471     auto expectedName = dyn_cast<StringAttr>(getNames()[i]);
472     if (!expectedName ||
473         tmpStream.str().drop_front() != expectedName.getValue()) {
474       namesDisagree = true;
475     }
476   }
477 
478   if (namesDisagree)
479     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
480   else
481     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
482 }
483 
484 // We set the SSA name in the asm syntax to the contents of the name
485 // attribute.
486 void StringAttrPrettyNameOp::getAsmResultNames(
487     function_ref<void(Value, StringRef)> setNameFn) {
488 
489   auto value = getNames();
490   for (size_t i = 0, e = value.size(); i != e; ++i)
491     if (auto str = dyn_cast<StringAttr>(value[i]))
492       if (!str.getValue().empty())
493         setNameFn(getResult(i), str.getValue());
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // CustomResultsNameOp
498 //===----------------------------------------------------------------------===//
499 
500 void CustomResultsNameOp::getAsmResultNames(
501     function_ref<void(Value, StringRef)> setNameFn) {
502   ArrayAttr value = getNames();
503   for (size_t i = 0, e = value.size(); i != e; ++i)
504     if (auto str = dyn_cast<StringAttr>(value[i]))
505       if (!str.empty())
506         setNameFn(getResult(i), str.getValue());
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // ResultNameFromTypeOp
511 //===----------------------------------------------------------------------===//
512 
513 void ResultNameFromTypeOp::getAsmResultNames(
514     function_ref<void(Value, StringRef)> setNameFn) {
515   auto result = getResult();
516   auto setResultNameFn = [&](::llvm::StringRef name) {
517     setNameFn(result, name);
518   };
519   auto opAsmTypeInterface =
520       ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType());
521   opAsmTypeInterface.getAsmName(setResultNameFn);
522 }
523 
524 //===----------------------------------------------------------------------===//
525 // BlockArgumentNameFromTypeOp
526 //===----------------------------------------------------------------------===//
527 
528 void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames(
529     ::mlir::Region &region, ::mlir::OpAsmSetValueNameFn setNameFn) {
530   for (auto &block : region) {
531     for (auto arg : block.getArguments()) {
532       if (auto opAsmTypeInterface =
533               ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) {
534         auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); };
535         opAsmTypeInterface.getAsmName(setArgNameFn);
536       }
537     }
538   }
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // ResultTypeWithTraitOp
543 //===----------------------------------------------------------------------===//
544 
545 LogicalResult ResultTypeWithTraitOp::verify() {
546   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
547     return success();
548   return emitError("result type should have trait 'TestTypeTrait'");
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // AttrWithTraitOp
553 //===----------------------------------------------------------------------===//
554 
555 LogicalResult AttrWithTraitOp::verify() {
556   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
557     return success();
558   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
559 }
560 
561 //===----------------------------------------------------------------------===//
562 // RegionIfOp
563 //===----------------------------------------------------------------------===//
564 
565 void RegionIfOp::print(OpAsmPrinter &p) {
566   p << " ";
567   p.printOperands(getOperands());
568   p << ": " << getOperandTypes();
569   p.printArrowTypeList(getResultTypes());
570   p << " then ";
571   p.printRegion(getThenRegion(),
572                 /*printEntryBlockArgs=*/true,
573                 /*printBlockTerminators=*/true);
574   p << " else ";
575   p.printRegion(getElseRegion(),
576                 /*printEntryBlockArgs=*/true,
577                 /*printBlockTerminators=*/true);
578   p << " join ";
579   p.printRegion(getJoinRegion(),
580                 /*printEntryBlockArgs=*/true,
581                 /*printBlockTerminators=*/true);
582 }
583 
584 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
585   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
586   SmallVector<Type, 2> operandTypes;
587 
588   result.regions.reserve(3);
589   Region *thenRegion = result.addRegion();
590   Region *elseRegion = result.addRegion();
591   Region *joinRegion = result.addRegion();
592 
593   // Parse operand, type and arrow type lists.
594   if (parser.parseOperandList(operandInfos) ||
595       parser.parseColonTypeList(operandTypes) ||
596       parser.parseArrowTypeList(result.types))
597     return failure();
598 
599   // Parse all attached regions.
600   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
601       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
602       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
603     return failure();
604 
605   return parser.resolveOperands(operandInfos, operandTypes,
606                                 parser.getCurrentLocation(), result.operands);
607 }
608 
609 OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
610   assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
611          "invalid region index");
612   return getOperands();
613 }
614 
615 void RegionIfOp::getSuccessorRegions(
616     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
617   // We always branch to the join region.
618   if (!point.isParent()) {
619     if (point != getJoinRegion())
620       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
621     else
622       regions.push_back(RegionSuccessor(getResults()));
623     return;
624   }
625 
626   // The then and else regions are the entry regions of this op.
627   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
628   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
629 }
630 
631 void RegionIfOp::getRegionInvocationBounds(
632     ArrayRef<Attribute> operands,
633     SmallVectorImpl<InvocationBounds> &invocationBounds) {
634   // Each region is invoked at most once.
635   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
636 }
637 
638 //===----------------------------------------------------------------------===//
639 // AnyCondOp
640 //===----------------------------------------------------------------------===//
641 
642 void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
643                                     SmallVectorImpl<RegionSuccessor> &regions) {
644   // The parent op branches into the only region, and the region branches back
645   // to the parent op.
646   if (point.isParent())
647     regions.emplace_back(&getRegion());
648   else
649     regions.emplace_back(getResults());
650 }
651 
652 void AnyCondOp::getRegionInvocationBounds(
653     ArrayRef<Attribute> operands,
654     SmallVectorImpl<InvocationBounds> &invocationBounds) {
655   invocationBounds.emplace_back(1, 1);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // SingleBlockImplicitTerminatorOp
660 //===----------------------------------------------------------------------===//
661 
662 /// Testing the correctness of some traits.
663 static_assert(
664     llvm::is_detected<OpTrait::has_implicit_terminator_t,
665                       SingleBlockImplicitTerminatorOp>::value,
666     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
667 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
668                   SingleBlockImplicitTerminatorOp>::value,
669               "hasSingleBlockImplicitTerminator does not match "
670               "SingleBlockImplicitTerminatorOp");
671 
672 //===----------------------------------------------------------------------===//
673 // SingleNoTerminatorCustomAsmOp
674 //===----------------------------------------------------------------------===//
675 
676 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
677                                                  OperationState &state) {
678   Region *body = state.addRegion();
679   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
680     return failure();
681   return success();
682 }
683 
684 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
685   printer.printRegion(
686       getRegion(), /*printEntryBlockArgs=*/false,
687       // This op has a single block without terminators. But explicitly mark
688       // as not printing block terminators for testing.
689       /*printBlockTerminators=*/false);
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // TestVerifiersOp
694 //===----------------------------------------------------------------------===//
695 
696 LogicalResult TestVerifiersOp::verify() {
697   if (!getRegion().hasOneBlock())
698     return emitOpError("`hasOneBlock` trait hasn't been verified");
699 
700   Operation *definingOp = getInput().getDefiningOp();
701   if (definingOp && failed(mlir::verify(definingOp)))
702     return emitOpError("operand hasn't been verified");
703 
704   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
705   // loop.
706   mlir::emitRemark(getLoc(), "success run of verifier");
707 
708   return success();
709 }
710 
711 LogicalResult TestVerifiersOp::verifyRegions() {
712   if (!getRegion().hasOneBlock())
713     return emitOpError("`hasOneBlock` trait hasn't been verified");
714 
715   for (Block &block : getRegion())
716     for (Operation &op : block)
717       if (failed(mlir::verify(&op)))
718         return emitOpError("nested op hasn't been verified");
719 
720   // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier
721   // loop.
722   mlir::emitRemark(getLoc(), "success run of region verifier");
723 
724   return success();
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // Test InferIntRangeInterface
729 //===----------------------------------------------------------------------===//
730 
731 //===----------------------------------------------------------------------===//
732 // TestWithBoundsOp
733 
734 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
735                                          SetIntRangeFn setResultRanges) {
736   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
737 }
738 
739 //===----------------------------------------------------------------------===//
740 // TestWithBoundsRegionOp
741 
742 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
743                                           OperationState &result) {
744   if (parser.parseOptionalAttrDict(result.attributes))
745     return failure();
746 
747   // Parse the input argument
748   OpAsmParser::Argument argInfo;
749   if (failed(parser.parseArgument(argInfo, true)))
750     return failure();
751 
752   // Parse the body region, and reuse the operand info as the argument info.
753   Region *body = result.addRegion();
754   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
755 }
756 
757 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
758   p.printOptionalAttrDict((*this)->getAttrs());
759   p << ' ';
760   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
761                         /*omitType=*/false);
762   p << ' ';
763   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
764 }
765 
766 void TestWithBoundsRegionOp::inferResultRanges(
767     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
768   Value arg = getRegion().getArgument(0);
769   setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // TestIncrementOp
774 
775 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
776                                         SetIntRangeFn setResultRanges) {
777   const ConstantIntRanges &range = argRanges[0];
778   APInt one(range.umin().getBitWidth(), 1);
779   setResultRanges(getResult(),
780                   {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
781                    range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // TestReflectBoundsOp
786 
787 void TestReflectBoundsOp::inferResultRanges(
788     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
789   const ConstantIntRanges &range = argRanges[0];
790   MLIRContext *ctx = getContext();
791   Builder b(ctx);
792   Type sIntTy, uIntTy;
793   // For plain `IntegerType`s, we can derive the appropriate signed and unsigned
794   // Types for the Attributes.
795   Type type = getElementTypeOrSelf(getType());
796   if (auto intTy = llvm::dyn_cast<IntegerType>(type)) {
797     unsigned bitwidth = intTy.getWidth();
798     sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true);
799     uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false);
800   } else
801     sIntTy = uIntTy = type;
802 
803   setUminAttr(b.getIntegerAttr(uIntTy, range.umin()));
804   setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax()));
805   setSminAttr(b.getIntegerAttr(sIntTy, range.smin()));
806   setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax()));
807   setResultRanges(getResult(), range);
808 }
809 
810 //===----------------------------------------------------------------------===//
811 // ConversionFuncOp
812 //===----------------------------------------------------------------------===//
813 
814 ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
815                                     OperationState &result) {
816   auto buildFuncType =
817       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
818          function_interface_impl::VariadicFlag,
819          std::string &) { return builder.getFunctionType(argTypes, results); };
820 
821   return function_interface_impl::parseFunctionOp(
822       parser, result, /*allowVariadic=*/false,
823       getFunctionTypeAttrName(result.name), buildFuncType,
824       getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
825 }
826 
827 void ConversionFuncOp::print(OpAsmPrinter &p) {
828   function_interface_impl::printFunctionOp(
829       p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
830       getArgAttrsAttrName(), getResAttrsAttrName());
831 }
832 
833 //===----------------------------------------------------------------------===//
834 // ReifyBoundOp
835 //===----------------------------------------------------------------------===//
836 
837 mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
838   if (getType() == "EQ")
839     return mlir::presburger::BoundType::EQ;
840   if (getType() == "LB")
841     return mlir::presburger::BoundType::LB;
842   if (getType() == "UB")
843     return mlir::presburger::BoundType::UB;
844   llvm_unreachable("invalid bound type");
845 }
846 
847 LogicalResult ReifyBoundOp::verify() {
848   if (isa<ShapedType>(getVar().getType())) {
849     if (!getDim().has_value())
850       return emitOpError("expected 'dim' attribute for shaped type variable");
851   } else if (getVar().getType().isIndex()) {
852     if (getDim().has_value())
853       return emitOpError("unexpected 'dim' attribute for index variable");
854   } else {
855     return emitOpError("expected index-typed variable or shape type variable");
856   }
857   if (getConstant() && getScalable())
858     return emitOpError("'scalable' and 'constant' are mutually exlusive");
859   if (getScalable() != getVscaleMin().has_value())
860     return emitOpError("expected 'vscale_min' if and only if 'scalable'");
861   if (getScalable() != getVscaleMax().has_value())
862     return emitOpError("expected 'vscale_min' if and only if 'scalable'");
863   return success();
864 }
865 
866 ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
867   if (getDim().has_value())
868     return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
869   return ValueBoundsConstraintSet::Variable(getVar());
870 }
871 
872 //===----------------------------------------------------------------------===//
873 // CompareOp
874 //===----------------------------------------------------------------------===//
875 
876 ValueBoundsConstraintSet::ComparisonOperator
877 CompareOp::getComparisonOperator() {
878   if (getCmp() == "EQ")
879     return ValueBoundsConstraintSet::ComparisonOperator::EQ;
880   if (getCmp() == "LT")
881     return ValueBoundsConstraintSet::ComparisonOperator::LT;
882   if (getCmp() == "LE")
883     return ValueBoundsConstraintSet::ComparisonOperator::LE;
884   if (getCmp() == "GT")
885     return ValueBoundsConstraintSet::ComparisonOperator::GT;
886   if (getCmp() == "GE")
887     return ValueBoundsConstraintSet::ComparisonOperator::GE;
888   llvm_unreachable("invalid comparison operator");
889 }
890 
891 mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
892   if (!getLhsMap())
893     return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
894   SmallVector<Value> mapOperands(
895       getVarOperands().slice(0, getLhsMap()->getNumInputs()));
896   return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
897 }
898 
899 mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
900   int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
901   if (!getRhsMap())
902     return ValueBoundsConstraintSet::Variable(
903         getVarOperands()[rhsOperandsBegin]);
904   SmallVector<Value> mapOperands(
905       getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
906   return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
907 }
908 
909 LogicalResult CompareOp::verify() {
910   if (getCompose() && (getLhsMap() || getRhsMap()))
911     return emitOpError(
912         "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
913   int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
914   expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
915   if (getVarOperands().size() != size_t(expectedNumOperands))
916     return emitOpError("expected ")
917            << expectedNumOperands << " operands, but got "
918            << getVarOperands().size();
919   return success();
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // TestOpInPlaceSelfFold
924 //===----------------------------------------------------------------------===//
925 
926 OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) {
927   if (!getFolded()) {
928     // The folder adds the "folded" if not present.
929     setFolded(true);
930     return getResult();
931   }
932   return {};
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // TestOpFoldWithFoldAdaptor
937 //===----------------------------------------------------------------------===//
938 
939 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
940   int64_t sum = 0;
941   if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()))
942     sum += value.getValue().getSExtValue();
943 
944   for (Attribute attr : adaptor.getVariadic())
945     if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
946       sum += 2 * value.getValue().getSExtValue();
947 
948   for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar())
949     for (Attribute attr : attrs)
950       if (auto value = dyn_cast_or_null<IntegerAttr>(attr))
951         sum += 3 * value.getValue().getSExtValue();
952 
953   sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());
954 
955   return IntegerAttr::get(getType(), sum);
956 }
957 
958 //===----------------------------------------------------------------------===//
959 // OpWithInferTypeAdaptorInterfaceOp
960 //===----------------------------------------------------------------------===//
961 
962 LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
963     MLIRContext *, std::optional<Location> location,
964     OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
965     SmallVectorImpl<Type> &inferredReturnTypes) {
966   if (adaptor.getX().getType() != adaptor.getY().getType()) {
967     return emitOptionalError(location, "operand type mismatch ",
968                              adaptor.getX().getType(), " vs ",
969                              adaptor.getY().getType());
970   }
971   inferredReturnTypes.assign({adaptor.getX().getType()});
972   return success();
973 }
974 
975 //===----------------------------------------------------------------------===//
976 // OpWithRefineTypeInterfaceOp
977 //===----------------------------------------------------------------------===//
978 
979 // TODO: We should be able to only define either inferReturnType or
980 // refineReturnType, currently only refineReturnType can be omitted.
981 LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
982     MLIRContext *context, std::optional<Location> location, ValueRange operands,
983     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
984     SmallVectorImpl<Type> &returnTypes) {
985   returnTypes.clear();
986   return OpWithRefineTypeInterfaceOp::refineReturnTypes(
987       context, location, operands, attributes, properties, regions,
988       returnTypes);
989 }
990 
991 LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
992     MLIRContext *, std::optional<Location> location, ValueRange operands,
993     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
994     SmallVectorImpl<Type> &returnTypes) {
995   if (operands[0].getType() != operands[1].getType()) {
996     return emitOptionalError(location, "operand type mismatch ",
997                              operands[0].getType(), " vs ",
998                              operands[1].getType());
999   }
1000   // TODO: Add helper to make this more concise to write.
1001   if (returnTypes.empty())
1002     returnTypes.resize(1, nullptr);
1003   if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1004     return emitOptionalError(location,
1005                              "required first operand and result to match");
1006   returnTypes[0] = operands[0].getType();
1007   return success();
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // OpWithShapedTypeInferTypeAdaptorInterfaceOp
1012 //===----------------------------------------------------------------------===//
1013 
1014 LogicalResult
1015 OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
1016     MLIRContext *context, std::optional<Location> location,
1017     OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
1018     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1019   // Create return type consisting of the last element of the first operand.
1020   auto operandType = adaptor.getOperand1().getType();
1021   auto sval = dyn_cast<ShapedType>(operandType);
1022   if (!sval)
1023     return emitOptionalError(location, "only shaped type operands allowed");
1024   int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
1025   auto type = IntegerType::get(context, 17);
1026 
1027   Attribute encoding;
1028   if (auto rankedTy = dyn_cast<RankedTensorType>(sval))
1029     encoding = rankedTy.getEncoding();
1030   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
1031   return success();
1032 }
1033 
1034 LogicalResult
1035 OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
1036     OpBuilder &builder, ValueRange operands,
1037     llvm::SmallVectorImpl<Value> &shapes) {
1038   shapes = SmallVector<Value, 1>{
1039       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1040   return success();
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // TestOpWithPropertiesAndInferredType
1045 //===----------------------------------------------------------------------===//
1046 
1047 LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes(
1048     MLIRContext *context, std::optional<Location>, ValueRange operands,
1049     DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
1050     SmallVectorImpl<Type> &inferredReturnTypes) {
1051 
1052   Adaptor adaptor(operands, attributes, properties, regions);
1053   inferredReturnTypes.push_back(IntegerType::get(
1054       context, adaptor.getLhs() + adaptor.getProperties().rhs));
1055   return success();
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // LoopBlockOp
1060 //===----------------------------------------------------------------------===//
1061 
1062 void LoopBlockOp::getSuccessorRegions(
1063     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1064   regions.emplace_back(&getBody(), getBody().getArguments());
1065   if (point.isParent())
1066     return;
1067 
1068   regions.emplace_back((*this)->getResults());
1069 }
1070 
1071 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1072   assert(point == getBody());
1073   return MutableOperandRange(getInitMutable());
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // LoopBlockTerminatorOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 MutableOperandRange
1081 LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
1082   if (point.isParent())
1083     return getExitArgMutable();
1084   return getNextIterArgMutable();
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
// TestNoTerminatorOp
1089 //===----------------------------------------------------------------------===//
1090 
1091 void TestNoTerminatorOp::getSuccessorRegions(
1092     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {}
1093 
1094 //===----------------------------------------------------------------------===//
// ManualCppOpWithFold
1096 //===----------------------------------------------------------------------===//
1097 
1098 OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) {
1099   // Just a simple fold for testing purposes that reads an operands constant
1100   // value and returns it.
1101   if (!attributes.empty())
1102     return attributes.front();
1103   return nullptr;
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // Tensor/Buffer Ops
1108 //===----------------------------------------------------------------------===//
1109 
1110 void ReadBufferOp::getEffects(
1111     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1112         &effects) {
1113   // The buffer operand is read.
1114   effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(),
1115                        SideEffects::DefaultResource::get());
1116   // The buffer contents are dumped.
1117   effects.emplace_back(MemoryEffects::Write::get(),
1118                        SideEffects::DefaultResource::get());
1119 }
1120 
1121 //===----------------------------------------------------------------------===//
1122 // Test Dataflow
1123 //===----------------------------------------------------------------------===//
1124 
1125 //===----------------------------------------------------------------------===//
1126 // TestCallAndStoreOp
1127 
1128 CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() {
1129   return getCallee();
1130 }
1131 
1132 void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1133   setCalleeAttr(cast<SymbolRefAttr>(callee));
1134 }
1135 
1136 Operation::operand_range TestCallAndStoreOp::getArgOperands() {
1137   return getCalleeOperands();
1138 }
1139 
1140 MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1141   return getCalleeOperandsMutable();
1142 }
1143 
1144 //===----------------------------------------------------------------------===//
1145 // TestCallOnDeviceOp
1146 
1147 CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() {
1148   return getCallee();
1149 }
1150 
1151 void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1152   setCalleeAttr(cast<SymbolRefAttr>(callee));
1153 }
1154 
1155 Operation::operand_range TestCallOnDeviceOp::getArgOperands() {
1156   return getForwardedOperands();
1157 }
1158 
1159 MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() {
1160   return getForwardedOperandsMutable();
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // TestStoreWithARegion
1165 
1166 void TestStoreWithARegion::getSuccessorRegions(
1167     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1168   if (point.isParent())
1169     regions.emplace_back(&getBody(), getBody().front().getArguments());
1170   else
1171     regions.emplace_back();
1172 }
1173 
1174 //===----------------------------------------------------------------------===//
1175 // TestStoreWithALoopRegion
1176 
1177 void TestStoreWithALoopRegion::getSuccessorRegions(
1178     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1179   // Both the operation itself and the region may be branching into the body or
1180   // back into the operation itself. It is possible for the operation not to
1181   // enter the body.
1182   regions.emplace_back(
1183       RegionSuccessor(&getBody(), getBody().front().getArguments()));
1184   regions.emplace_back();
1185 }
1186 
1187 //===----------------------------------------------------------------------===//
1188 // TestVersionedOpA
1189 //===----------------------------------------------------------------------===//
1190 
1191 LogicalResult
1192 TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader,
1193                                  mlir::OperationState &state) {
1194   auto &prop = state.getOrAddProperties<Properties>();
1195   if (mlir::failed(reader.readAttribute(prop.dims)))
1196     return mlir::failure();
1197 
1198   // Check if we have a version. If not, assume we are parsing the current
1199   // version.
1200   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1201   if (succeeded(maybeVersion)) {
1202     // If version is less than 2.0, there is no additional attribute to parse.
1203     // We can materialize missing properties post parsing before verification.
1204     const auto *version =
1205         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1206     if ((version->major_ < 2)) {
1207       return success();
1208     }
1209   }
1210 
1211   if (mlir::failed(reader.readAttribute(prop.modifier)))
1212     return mlir::failure();
1213   return mlir::success();
1214 }
1215 
1216 void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) {
1217   auto &prop = getProperties();
1218   writer.writeAttribute(prop.dims);
1219 
1220   auto maybeVersion = writer.getDialectVersion<test::TestDialect>();
1221   if (succeeded(maybeVersion)) {
1222     // If version is less than 2.0, there is no additional attribute to write.
1223     const auto *version =
1224         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1225     if ((version->major_ < 2)) {
1226       llvm::outs() << "downgrading op properties...\n";
1227       return;
1228     }
1229   }
1230   writer.writeAttribute(prop.modifier);
1231 }
1232 
1233 //===----------------------------------------------------------------------===//
1234 // TestOpWithVersionedProperties
1235 //===----------------------------------------------------------------------===//
1236 
1237 llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode(
1238     mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) {
1239   uint64_t value1, value2 = 0;
1240   if (failed(reader.readVarInt(value1)))
1241     return failure();
1242 
1243   // Check if we have a version. If not, assume we are parsing the current
1244   // version.
1245   auto maybeVersion = reader.getDialectVersion<test::TestDialect>();
1246   bool needToParseAnotherInt = true;
1247   if (succeeded(maybeVersion)) {
1248     // If version is less than 2.0, there is no additional attribute to parse.
1249     // We can materialize missing properties post parsing before verification.
1250     const auto *version =
1251         reinterpret_cast<const TestDialectVersion *>(*maybeVersion);
1252     if ((version->major_ < 2))
1253       needToParseAnotherInt = false;
1254   }
1255   if (needToParseAnotherInt && failed(reader.readVarInt(value2)))
1256     return failure();
1257 
1258   prop.value1 = value1;
1259   prop.value2 = value2;
1260   return success();
1261 }
1262 
1263 void TestOpWithVersionedProperties::writeToMlirBytecode(
1264     mlir::DialectBytecodeWriter &writer,
1265     const test::VersionedProperties &prop) {
1266   writer.writeVarInt(prop.value1);
1267   writer.writeVarInt(prop.value2);
1268 }
1269 
1270 //===----------------------------------------------------------------------===//
1271 // TestMultiSlotAlloca
1272 //===----------------------------------------------------------------------===//
1273 
1274 llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
1275   SmallVector<MemorySlot> slots;
1276   for (Value result : getResults()) {
1277     slots.push_back(MemorySlot{
1278         result, cast<MemRefType>(result.getType()).getElementType()});
1279   }
1280   return slots;
1281 }
1282 
1283 Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
1284                                            OpBuilder &builder) {
1285   return builder.create<TestOpConstant>(getLoc(), slot.elemType,
1286                                         builder.getI32IntegerAttr(42));
1287 }
1288 
1289 void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
1290                                               BlockArgument argument,
1291                                               OpBuilder &builder) {
1292   // Not relevant for testing.
1293 }
1294 
1295 /// Creates a new TestMultiSlotAlloca operation, just without the `slot`.
1296 static std::optional<TestMultiSlotAlloca>
1297 createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder,
1298                                 TestMultiSlotAlloca oldOp) {
1299 
1300   if (oldOp.getNumResults() == 1) {
1301     oldOp.erase();
1302     return std::nullopt;
1303   }
1304 
1305   SmallVector<Type> newTypes;
1306   SmallVector<Value> remainingValues;
1307 
1308   for (Value oldResult : oldOp.getResults()) {
1309     if (oldResult == slot.ptr)
1310       continue;
1311     remainingValues.push_back(oldResult);
1312     newTypes.push_back(oldResult.getType());
1313   }
1314 
1315   OpBuilder::InsertionGuard guard(builder);
1316   builder.setInsertionPoint(oldOp);
1317   auto replacement =
1318       builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes);
1319   for (auto [oldResult, newResult] :
1320        llvm::zip_equal(remainingValues, replacement.getResults()))
1321     oldResult.replaceAllUsesWith(newResult);
1322 
1323   oldOp.erase();
1324   return replacement;
1325 }
1326 
1327 std::optional<PromotableAllocationOpInterface>
1328 TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
1329                                              Value defaultValue,
1330                                              OpBuilder &builder) {
1331   if (defaultValue && defaultValue.use_empty())
1332     defaultValue.getDefiningOp()->erase();
1333   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1334 }
1335 
1336 SmallVector<DestructurableMemorySlot>
1337 TestMultiSlotAlloca::getDestructurableSlots() {
1338   SmallVector<DestructurableMemorySlot> slots;
1339   for (Value result : getResults()) {
1340     auto memrefType = cast<MemRefType>(result.getType());
1341     auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType);
1342     if (!destructurable)
1343       continue;
1344 
1345     std::optional<DenseMap<Attribute, Type>> destructuredType =
1346         destructurable.getSubelementIndexMap();
1347     if (!destructuredType)
1348       continue;
1349     slots.emplace_back(
1350         DestructurableMemorySlot{{result, memrefType}, *destructuredType});
1351   }
1352   return slots;
1353 }
1354 
1355 DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure(
1356     const DestructurableMemorySlot &slot,
1357     const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
1358     SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
1359   OpBuilder::InsertionGuard guard(builder);
1360   builder.setInsertionPointAfter(*this);
1361 
1362   DenseMap<Attribute, MemorySlot> slotMap;
1363 
1364   for (Attribute usedIndex : usedIndices) {
1365     Type elemType = slot.subelementTypes.lookup(usedIndex);
1366     MemRefType elemPtr = MemRefType::get({}, elemType);
1367     auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr);
1368     newAllocators.push_back(subAlloca);
1369     slotMap.try_emplace<MemorySlot>(usedIndex,
1370                                     {subAlloca.getResult(0), elemType});
1371   }
1372 
1373   return slotMap;
1374 }
1375 
1376 std::optional<DestructurableAllocationOpInterface>
1377 TestMultiSlotAlloca::handleDestructuringComplete(
1378     const DestructurableMemorySlot &slot, OpBuilder &builder) {
1379   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
1380 }
1381