1 //===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "TestOps.h" 11 #include "mlir/Dialect/Tensor/IR/Tensor.h" 12 #include "mlir/IR/Verifier.h" 13 #include "mlir/Interfaces/FunctionImplementation.h" 14 #include "mlir/Interfaces/MemorySlotInterfaces.h" 15 16 using namespace mlir; 17 using namespace test; 18 19 //===----------------------------------------------------------------------===// 20 // TestBranchOp 21 //===----------------------------------------------------------------------===// 22 23 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { 24 assert(index == 0 && "invalid successor index"); 25 return SuccessorOperands(getTargetOperandsMutable()); 26 } 27 28 //===----------------------------------------------------------------------===// 29 // TestProducingBranchOp 30 //===----------------------------------------------------------------------===// 31 32 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { 33 assert(index <= 1 && "invalid successor index"); 34 if (index == 1) 35 return SuccessorOperands(getFirstOperandsMutable()); 36 return SuccessorOperands(getSecondOperandsMutable()); 37 } 38 39 //===----------------------------------------------------------------------===// 40 // TestInternalBranchOp 41 //===----------------------------------------------------------------------===// 42 43 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { 44 assert(index <= 1 && "invalid successor index"); 45 if (index == 0) 46 return SuccessorOperands(0, getSuccessOperandsMutable()); 47 return SuccessorOperands(1, getErrorOperandsMutable()); 48 } 49 50 //===----------------------------------------------------------------------===// 51 // TestCallOp 52 //===----------------------------------------------------------------------===// 53 54 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 55 // Check that the callee attribute was specified. 56 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 57 if (!fnAttr) 58 return emitOpError("requires a 'callee' symbol reference attribute"); 59 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr)) 60 return emitOpError() << "'" << fnAttr.getValue() 61 << "' does not reference a valid function"; 62 return success(); 63 } 64 65 //===----------------------------------------------------------------------===// 66 // FoldToCallOp 67 //===----------------------------------------------------------------------===// 68 69 namespace { 70 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> { 71 using OpRewritePattern<FoldToCallOp>::OpRewritePattern; 72 73 LogicalResult matchAndRewrite(FoldToCallOp op, 74 PatternRewriter &rewriter) const override { 75 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(), 76 op.getCalleeAttr(), ValueRange()); 77 return success(); 78 } 79 }; 80 } // namespace 81 82 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, 83 MLIRContext *context) { 84 results.add<FoldToCallOpPattern>(context); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // IsolatedRegionOp - test parsing passthrough operands 89 //===----------------------------------------------------------------------===// 90 91 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, 92 OperationState &result) { 93 // Parse the input operand. 94 OpAsmParser::Argument argInfo; 95 argInfo.type = parser.getBuilder().getIndexType(); 96 if (parser.parseOperand(argInfo.ssaName) || 97 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) 98 return failure(); 99 100 // Parse the body region, and reuse the operand info as the argument info. 101 Region *body = result.addRegion(); 102 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); 103 } 104 105 void IsolatedRegionOp::print(OpAsmPrinter &p) { 106 p << ' '; 107 p.printOperand(getOperand()); 108 p.shadowRegionArgs(getRegion(), getOperand()); 109 p << ' '; 110 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // SSACFGRegionOp 115 //===----------------------------------------------------------------------===// 116 117 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { 118 return RegionKind::SSACFG; 119 } 120 121 //===----------------------------------------------------------------------===// 122 // GraphRegionOp 123 //===----------------------------------------------------------------------===// 124 125 RegionKind GraphRegionOp::getRegionKind(unsigned index) { 126 return RegionKind::Graph; 127 } 128 129 //===----------------------------------------------------------------------===// 130 // IsolatedGraphRegionOp 131 //===----------------------------------------------------------------------===// 132 133 RegionKind IsolatedGraphRegionOp::getRegionKind(unsigned index) { 134 return RegionKind::Graph; 135 } 136 137 //===----------------------------------------------------------------------===// 138 // AffineScopeOp 139 //===----------------------------------------------------------------------===// 140 141 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { 142 // Parse the body region, and reuse the operand info as the argument info. 143 Region *body = result.addRegion(); 144 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); 145 } 146 147 void AffineScopeOp::print(OpAsmPrinter &p) { 148 p << " "; 149 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // TestRemoveOpWithInnerOps 154 //===----------------------------------------------------------------------===// 155 156 namespace { 157 struct TestRemoveOpWithInnerOps 158 : public OpRewritePattern<TestOpWithRegionPattern> { 159 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern; 160 161 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } 162 163 LogicalResult matchAndRewrite(TestOpWithRegionPattern op, 164 PatternRewriter &rewriter) const override { 165 rewriter.eraseOp(op); 166 return success(); 167 } 168 }; 169 } // namespace 170 171 //===----------------------------------------------------------------------===// 172 // TestOpWithRegionPattern 173 //===----------------------------------------------------------------------===// 174 175 void TestOpWithRegionPattern::getCanonicalizationPatterns( 176 RewritePatternSet &results, MLIRContext *context) { 177 results.add<TestRemoveOpWithInnerOps>(context); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // TestOpWithRegionFold 182 //===----------------------------------------------------------------------===// 183 184 OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { 185 return getOperand(); 186 } 187 188 //===----------------------------------------------------------------------===// 189 // TestOpConstant 190 //===----------------------------------------------------------------------===// 191 192 OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } 193 194 //===----------------------------------------------------------------------===// 195 // TestOpWithVariadicResultsAndFolder 196 //===----------------------------------------------------------------------===// 197 198 LogicalResult TestOpWithVariadicResultsAndFolder::fold( 199 FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { 200 for (Value input : this->getOperands()) { 201 results.push_back(input); 202 } 203 return success(); 204 } 205 206 //===----------------------------------------------------------------------===// 207 // TestOpInPlaceFold 208 //===----------------------------------------------------------------------===// 209 210 OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { 211 // Exercise the fact that an operation created with createOrFold should be 212 // allowed to access its parent block. 213 assert(getOperation()->getBlock() && 214 "expected that operation is not unlinked"); 215 216 if (adaptor.getOp() && !getProperties().attr) { 217 // The folder adds "attr" if not present. 218 getProperties().attr = dyn_cast_or_null<IntegerAttr>(adaptor.getOp()); 219 return getResult(); 220 } 221 return {}; 222 } 223 224 //===----------------------------------------------------------------------===// 225 // OpWithInferTypeInterfaceOp 226 //===----------------------------------------------------------------------===// 227 228 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( 229 MLIRContext *, std::optional<Location> location, ValueRange operands, 230 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 231 SmallVectorImpl<Type> &inferredReturnTypes) { 232 if (operands[0].getType() != operands[1].getType()) { 233 return emitOptionalError(location, "operand type mismatch ", 234 operands[0].getType(), " vs ", 235 operands[1].getType()); 236 } 237 inferredReturnTypes.assign({operands[0].getType()}); 238 return success(); 239 } 240 241 //===----------------------------------------------------------------------===// 242 // OpWithShapedTypeInferTypeInterfaceOp 243 //===----------------------------------------------------------------------===// 244 245 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( 246 MLIRContext *context, std::optional<Location> location, 247 ValueShapeRange operands, DictionaryAttr attributes, 248 OpaqueProperties properties, RegionRange regions, 249 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 250 // Create return type consisting of the last element of the first operand. 251 auto operandType = operands.front().getType(); 252 auto sval = dyn_cast<ShapedType>(operandType); 253 if (!sval) 254 return emitOptionalError(location, "only shaped type operands allowed"); 255 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; 256 auto type = IntegerType::get(context, 17); 257 258 Attribute encoding; 259 if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) 260 encoding = rankedTy.getEncoding(); 261 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); 262 return success(); 263 } 264 265 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( 266 OpBuilder &builder, ValueRange operands, 267 llvm::SmallVectorImpl<Value> &shapes) { 268 shapes = SmallVector<Value, 1>{ 269 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 270 return success(); 271 } 272 273 //===----------------------------------------------------------------------===// 274 // OpWithResultShapeInterfaceOp 275 //===----------------------------------------------------------------------===// 276 277 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( 278 OpBuilder &builder, ValueRange operands, 279 llvm::SmallVectorImpl<Value> &shapes) { 280 Location loc = getLoc(); 281 shapes.reserve(operands.size()); 282 for (Value operand : llvm::reverse(operands)) { 283 auto rank = cast<RankedTensorType>(operand.getType()).getRank(); 284 auto currShape = llvm::to_vector<4>( 285 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value { 286 return builder.createOrFold<tensor::DimOp>(loc, operand, dim); 287 })); 288 shapes.push_back(builder.create<tensor::FromElementsOp>( 289 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), 290 currShape)); 291 } 292 return success(); 293 } 294 295 //===----------------------------------------------------------------------===// 296 // OpWithResultShapePerDimInterfaceOp 297 //===----------------------------------------------------------------------===// 298 299 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( 300 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { 301 Location loc = getLoc(); 302 shapes.reserve(getNumOperands()); 303 for (Value operand : llvm::reverse(getOperands())) { 304 auto tensorType = cast<RankedTensorType>(operand.getType()); 305 auto currShape = llvm::to_vector<4>(llvm::map_range( 306 llvm::seq<int64_t>(0, tensorType.getRank()), 307 [&](int64_t dim) -> OpFoldResult { 308 return tensorType.isDynamicDim(dim) 309 ? static_cast<OpFoldResult>( 310 builder.createOrFold<tensor::DimOp>(loc, operand, 311 dim)) 312 : static_cast<OpFoldResult>( 313 builder.getIndexAttr(tensorType.getDimSize(dim))); 314 })); 315 shapes.emplace_back(std::move(currShape)); 316 } 317 return success(); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // SideEffectOp 322 //===----------------------------------------------------------------------===// 323 324 namespace { 325 /// A test resource for side effects. 326 struct TestResource : public SideEffects::Resource::Base<TestResource> { 327 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) 328 329 StringRef getName() final { return "<Test>"; } 330 }; 331 } // namespace 332 333 void SideEffectOp::getEffects( 334 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 335 // Check for an effects attribute on the op instance. 336 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 337 if (!effectsAttr) 338 return; 339 340 for (Attribute element : effectsAttr) { 341 DictionaryAttr effectElement = cast<DictionaryAttr>(element); 342 343 // Get the specific memory effect. 344 MemoryEffects::Effect *effect = 345 StringSwitch<MemoryEffects::Effect *>( 346 cast<StringAttr>(effectElement.get("effect")).getValue()) 347 .Case("allocate", MemoryEffects::Allocate::get()) 348 .Case("free", MemoryEffects::Free::get()) 349 .Case("read", MemoryEffects::Read::get()) 350 .Case("write", MemoryEffects::Write::get()); 351 352 // Check for a non-default resource to use. 353 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 354 if (effectElement.get("test_resource")) 355 resource = TestResource::get(); 356 357 // Check for a result to affect. 358 if (effectElement.get("on_result")) 359 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); 360 else if (Attribute ref = effectElement.get("on_reference")) 361 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); 362 else 363 effects.emplace_back(effect, resource); 364 } 365 } 366 367 void SideEffectOp::getEffects( 368 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 369 testSideEffectOpGetEffect(getOperation(), effects); 370 } 371 372 void SideEffectWithRegionOp::getEffects( 373 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 374 // Check for an effects attribute on the op instance. 375 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects"); 376 if (!effectsAttr) 377 return; 378 379 for (Attribute element : effectsAttr) { 380 DictionaryAttr effectElement = cast<DictionaryAttr>(element); 381 382 // Get the specific memory effect. 383 MemoryEffects::Effect *effect = 384 StringSwitch<MemoryEffects::Effect *>( 385 cast<StringAttr>(effectElement.get("effect")).getValue()) 386 .Case("allocate", MemoryEffects::Allocate::get()) 387 .Case("free", MemoryEffects::Free::get()) 388 .Case("read", MemoryEffects::Read::get()) 389 .Case("write", MemoryEffects::Write::get()); 390 391 // Check for a non-default resource to use. 392 SideEffects::Resource *resource = SideEffects::DefaultResource::get(); 393 if (effectElement.get("test_resource")) 394 resource = TestResource::get(); 395 396 // Check for a result to affect. 397 if (effectElement.get("on_result")) 398 effects.emplace_back(effect, getOperation()->getOpResults()[0], resource); 399 else if (effectElement.get("on_operand")) 400 effects.emplace_back(effect, &getOperation()->getOpOperands()[0], 401 resource); 402 else if (effectElement.get("on_argument")) 403 effects.emplace_back(effect, getOperation()->getRegion(0).getArgument(0), 404 resource); 405 else if (Attribute ref = effectElement.get("on_reference")) 406 effects.emplace_back(effect, cast<SymbolRefAttr>(ref), resource); 407 else 408 effects.emplace_back(effect, resource); 409 } 410 } 411 412 void SideEffectWithRegionOp::getEffects( 413 SmallVectorImpl<TestEffects::EffectInstance> &effects) { 414 testSideEffectOpGetEffect(getOperation(), effects); 415 } 416 417 //===----------------------------------------------------------------------===// 418 // StringAttrPrettyNameOp 419 //===----------------------------------------------------------------------===// 420 421 // This op has fancy handling of its SSA result name. 422 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, 423 OperationState &result) { 424 // Add the result types. 425 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) 426 result.addTypes(parser.getBuilder().getIntegerType(32)); 427 428 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 429 return failure(); 430 431 // If the attribute dictionary contains no 'names' attribute, infer it from 432 // the SSA name (if specified). 433 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { 434 return attr.getName() == "names"; 435 }); 436 437 // If there was no name specified, check to see if there was a useful name 438 // specified in the asm file. 439 if (hadNames || parser.getNumResults() == 0) 440 return success(); 441 442 SmallVector<StringRef, 4> names; 443 auto *context = result.getContext(); 444 445 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { 446 auto resultName = parser.getResultName(i); 447 StringRef nameStr; 448 if (!resultName.first.empty() && !isdigit(resultName.first[0])) 449 nameStr = resultName.first; 450 451 names.push_back(nameStr); 452 } 453 454 auto namesAttr = parser.getBuilder().getStrArrayAttr(names); 455 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); 456 return success(); 457 } 458 459 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { 460 // Note that we only need to print the "name" attribute if the asmprinter 461 // result name disagrees with it. This can happen in strange cases, e.g. 462 // when there are conflicts. 463 bool namesDisagree = getNames().size() != getNumResults(); 464 465 SmallString<32> resultNameStr; 466 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { 467 resultNameStr.clear(); 468 llvm::raw_svector_ostream tmpStream(resultNameStr); 469 p.printOperand(getResult(i), tmpStream); 470 471 auto expectedName = dyn_cast<StringAttr>(getNames()[i]); 472 if (!expectedName || 473 tmpStream.str().drop_front() != expectedName.getValue()) { 474 namesDisagree = true; 475 } 476 } 477 478 if (namesDisagree) 479 p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); 480 else 481 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); 482 } 483 484 // We set the SSA name in the asm syntax to the contents of the name 485 // attribute. 486 void StringAttrPrettyNameOp::getAsmResultNames( 487 function_ref<void(Value, StringRef)> setNameFn) { 488 489 auto value = getNames(); 490 for (size_t i = 0, e = value.size(); i != e; ++i) 491 if (auto str = dyn_cast<StringAttr>(value[i])) 492 if (!str.getValue().empty()) 493 setNameFn(getResult(i), str.getValue()); 494 } 495 496 //===----------------------------------------------------------------------===// 497 // CustomResultsNameOp 498 //===----------------------------------------------------------------------===// 499 500 void CustomResultsNameOp::getAsmResultNames( 501 function_ref<void(Value, StringRef)> setNameFn) { 502 ArrayAttr value = getNames(); 503 for (size_t i = 0, e = value.size(); i != e; ++i) 504 if (auto str = dyn_cast<StringAttr>(value[i])) 505 if (!str.empty()) 506 setNameFn(getResult(i), str.getValue()); 507 } 508 509 //===----------------------------------------------------------------------===// 510 // ResultNameFromTypeOp 511 //===----------------------------------------------------------------------===// 512 513 void ResultNameFromTypeOp::getAsmResultNames( 514 function_ref<void(Value, StringRef)> setNameFn) { 515 auto result = getResult(); 516 auto setResultNameFn = [&](::llvm::StringRef name) { 517 setNameFn(result, name); 518 }; 519 auto opAsmTypeInterface = 520 ::mlir::cast<::mlir::OpAsmTypeInterface>(result.getType()); 521 opAsmTypeInterface.getAsmName(setResultNameFn); 522 } 523 524 //===----------------------------------------------------------------------===// 525 // BlockArgumentNameFromTypeOp 526 //===----------------------------------------------------------------------===// 527 528 void BlockArgumentNameFromTypeOp::getAsmBlockArgumentNames( 529 ::mlir::Region ®ion, ::mlir::OpAsmSetValueNameFn setNameFn) { 530 for (auto &block : region) { 531 for (auto arg : block.getArguments()) { 532 if (auto opAsmTypeInterface = 533 ::mlir::dyn_cast<::mlir::OpAsmTypeInterface>(arg.getType())) { 534 auto setArgNameFn = [&](StringRef name) { setNameFn(arg, name); }; 535 opAsmTypeInterface.getAsmName(setArgNameFn); 536 } 537 } 538 } 539 } 540 541 //===----------------------------------------------------------------------===// 542 // ResultTypeWithTraitOp 543 //===----------------------------------------------------------------------===// 544 545 LogicalResult ResultTypeWithTraitOp::verify() { 546 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>()) 547 return success(); 548 return emitError("result type should have trait 'TestTypeTrait'"); 549 } 550 551 //===----------------------------------------------------------------------===// 552 // AttrWithTraitOp 553 //===----------------------------------------------------------------------===// 554 555 LogicalResult AttrWithTraitOp::verify() { 556 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>()) 557 return success(); 558 return emitError("'attr' attribute should have trait 'TestAttrTrait'"); 559 } 560 561 //===----------------------------------------------------------------------===// 562 // RegionIfOp 563 //===----------------------------------------------------------------------===// 564 565 void RegionIfOp::print(OpAsmPrinter &p) { 566 p << " "; 567 p.printOperands(getOperands()); 568 p << ": " << getOperandTypes(); 569 p.printArrowTypeList(getResultTypes()); 570 p << " then "; 571 p.printRegion(getThenRegion(), 572 /*printEntryBlockArgs=*/true, 573 /*printBlockTerminators=*/true); 574 p << " else "; 575 p.printRegion(getElseRegion(), 576 /*printEntryBlockArgs=*/true, 577 /*printBlockTerminators=*/true); 578 p << " join "; 579 p.printRegion(getJoinRegion(), 580 /*printEntryBlockArgs=*/true, 581 /*printBlockTerminators=*/true); 582 } 583 584 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { 585 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos; 586 SmallVector<Type, 2> operandTypes; 587 588 result.regions.reserve(3); 589 Region *thenRegion = result.addRegion(); 590 Region *elseRegion = result.addRegion(); 591 Region *joinRegion = result.addRegion(); 592 593 // Parse operand, type and arrow type lists. 594 if (parser.parseOperandList(operandInfos) || 595 parser.parseColonTypeList(operandTypes) || 596 parser.parseArrowTypeList(result.types)) 597 return failure(); 598 599 // Parse all attached regions. 600 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || 601 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || 602 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) 603 return failure(); 604 605 return parser.resolveOperands(operandInfos, operandTypes, 606 parser.getCurrentLocation(), result.operands); 607 } 608 609 OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { 610 assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && 611 "invalid region index"); 612 return getOperands(); 613 } 614 615 void RegionIfOp::getSuccessorRegions( 616 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 617 // We always branch to the join region. 618 if (!point.isParent()) { 619 if (point != getJoinRegion()) 620 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); 621 else 622 regions.push_back(RegionSuccessor(getResults())); 623 return; 624 } 625 626 // The then and else regions are the entry regions of this op. 627 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); 628 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); 629 } 630 631 void RegionIfOp::getRegionInvocationBounds( 632 ArrayRef<Attribute> operands, 633 SmallVectorImpl<InvocationBounds> &invocationBounds) { 634 // Each region is invoked at most once. 635 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); 636 } 637 638 //===----------------------------------------------------------------------===// 639 // AnyCondOp 640 //===----------------------------------------------------------------------===// 641 642 void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, 643 SmallVectorImpl<RegionSuccessor> ®ions) { 644 // The parent op branches into the only region, and the region branches back 645 // to the parent op. 646 if (point.isParent()) 647 regions.emplace_back(&getRegion()); 648 else 649 regions.emplace_back(getResults()); 650 } 651 652 void AnyCondOp::getRegionInvocationBounds( 653 ArrayRef<Attribute> operands, 654 SmallVectorImpl<InvocationBounds> &invocationBounds) { 655 invocationBounds.emplace_back(1, 1); 656 } 657 658 //===----------------------------------------------------------------------===// 659 // SingleBlockImplicitTerminatorOp 660 //===----------------------------------------------------------------------===// 661 662 /// Testing the correctness of some traits. 663 static_assert( 664 llvm::is_detected<OpTrait::has_implicit_terminator_t, 665 SingleBlockImplicitTerminatorOp>::value, 666 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); 667 static_assert(OpTrait::hasSingleBlockImplicitTerminator< 668 SingleBlockImplicitTerminatorOp>::value, 669 "hasSingleBlockImplicitTerminator does not match " 670 "SingleBlockImplicitTerminatorOp"); 671 672 //===----------------------------------------------------------------------===// 673 // SingleNoTerminatorCustomAsmOp 674 //===----------------------------------------------------------------------===// 675 676 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, 677 OperationState &state) { 678 Region *body = state.addRegion(); 679 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) 680 return failure(); 681 return success(); 682 } 683 684 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { 685 printer.printRegion( 686 getRegion(), /*printEntryBlockArgs=*/false, 687 // This op has a single block without terminators. But explicitly mark 688 // as not printing block terminators for testing. 689 /*printBlockTerminators=*/false); 690 } 691 692 //===----------------------------------------------------------------------===// 693 // TestVerifiersOp 694 //===----------------------------------------------------------------------===// 695 696 LogicalResult TestVerifiersOp::verify() { 697 if (!getRegion().hasOneBlock()) 698 return emitOpError("`hasOneBlock` trait hasn't been verified"); 699 700 Operation *definingOp = getInput().getDefiningOp(); 701 if (definingOp && failed(mlir::verify(definingOp))) 702 return emitOpError("operand hasn't been verified"); 703 704 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier 705 // loop. 706 mlir::emitRemark(getLoc(), "success run of verifier"); 707 708 return success(); 709 } 710 711 LogicalResult TestVerifiersOp::verifyRegions() { 712 if (!getRegion().hasOneBlock()) 713 return emitOpError("`hasOneBlock` trait hasn't been verified"); 714 715 for (Block &block : getRegion()) 716 for (Operation &op : block) 717 if (failed(mlir::verify(&op))) 718 return emitOpError("nested op hasn't been verified"); 719 720 // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier 721 // loop. 722 mlir::emitRemark(getLoc(), "success run of region verifier"); 723 724 return success(); 725 } 726 727 //===----------------------------------------------------------------------===// 728 // Test InferIntRangeInterface 729 //===----------------------------------------------------------------------===// 730 731 //===----------------------------------------------------------------------===// 732 // TestWithBoundsOp 733 734 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 735 SetIntRangeFn setResultRanges) { 736 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); 737 } 738 739 //===----------------------------------------------------------------------===// 740 // TestWithBoundsRegionOp 741 742 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, 743 OperationState &result) { 744 if (parser.parseOptionalAttrDict(result.attributes)) 745 return failure(); 746 747 // Parse the input argument 748 OpAsmParser::Argument argInfo; 749 if (failed(parser.parseArgument(argInfo, true))) 750 return failure(); 751 752 // Parse the body region, and reuse the operand info as the argument info. 753 Region *body = result.addRegion(); 754 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); 755 } 756 757 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { 758 p.printOptionalAttrDict((*this)->getAttrs()); 759 p << ' '; 760 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, 761 /*omitType=*/false); 762 p << ' '; 763 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); 764 } 765 766 void TestWithBoundsRegionOp::inferResultRanges( 767 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 768 Value arg = getRegion().getArgument(0); 769 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); 770 } 771 772 //===----------------------------------------------------------------------===// 773 // TestIncrementOp 774 775 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, 776 SetIntRangeFn setResultRanges) { 777 const ConstantIntRanges &range = argRanges[0]; 778 APInt one(range.umin().getBitWidth(), 1); 779 setResultRanges(getResult(), 780 {range.umin().uadd_sat(one), range.umax().uadd_sat(one), 781 range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); 782 } 783 784 //===----------------------------------------------------------------------===// 785 // TestReflectBoundsOp 786 787 void TestReflectBoundsOp::inferResultRanges( 788 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) { 789 const ConstantIntRanges &range = argRanges[0]; 790 MLIRContext *ctx = getContext(); 791 Builder b(ctx); 792 Type sIntTy, uIntTy; 793 // For plain `IntegerType`s, we can derive the appropriate signed and unsigned 794 // Types for the Attributes. 795 Type type = getElementTypeOrSelf(getType()); 796 if (auto intTy = llvm::dyn_cast<IntegerType>(type)) { 797 unsigned bitwidth = intTy.getWidth(); 798 sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true); 799 uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false); 800 } else 801 sIntTy = uIntTy = type; 802 803 setUminAttr(b.getIntegerAttr(uIntTy, range.umin())); 804 setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax())); 805 setSminAttr(b.getIntegerAttr(sIntTy, range.smin())); 806 setSmaxAttr(b.getIntegerAttr(sIntTy, range.smax())); 807 setResultRanges(getResult(), range); 808 } 809 810 //===----------------------------------------------------------------------===// 811 // ConversionFuncOp 812 //===----------------------------------------------------------------------===// 813 814 ParseResult ConversionFuncOp::parse(OpAsmParser &parser, 815 OperationState &result) { 816 auto buildFuncType = 817 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 818 function_interface_impl::VariadicFlag, 819 std::string &) { return builder.getFunctionType(argTypes, results); }; 820 821 return function_interface_impl::parseFunctionOp( 822 parser, result, /*allowVariadic=*/false, 823 getFunctionTypeAttrName(result.name), buildFuncType, 824 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); 825 } 826 827 void ConversionFuncOp::print(OpAsmPrinter &p) { 828 function_interface_impl::printFunctionOp( 829 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), 830 getArgAttrsAttrName(), getResAttrsAttrName()); 831 } 832 833 //===----------------------------------------------------------------------===// 834 // ReifyBoundOp 835 //===----------------------------------------------------------------------===// 836 837 mlir::presburger::BoundType ReifyBoundOp::getBoundType() { 838 if (getType() == "EQ") 839 return mlir::presburger::BoundType::EQ; 840 if (getType() == "LB") 841 return mlir::presburger::BoundType::LB; 842 if (getType() == "UB") 843 return mlir::presburger::BoundType::UB; 844 llvm_unreachable("invalid bound type"); 845 } 846 847 LogicalResult ReifyBoundOp::verify() { 848 if (isa<ShapedType>(getVar().getType())) { 849 if (!getDim().has_value()) 850 return emitOpError("expected 'dim' attribute for shaped type variable"); 851 } else if (getVar().getType().isIndex()) { 852 if (getDim().has_value()) 853 return emitOpError("unexpected 'dim' attribute for index variable"); 854 } else { 855 return emitOpError("expected index-typed variable or shape type variable"); 856 } 857 if (getConstant() && getScalable()) 858 return emitOpError("'scalable' and 'constant' are mutually exlusive"); 859 if (getScalable() != getVscaleMin().has_value()) 860 return emitOpError("expected 'vscale_min' if and only if 'scalable'"); 861 if (getScalable() != getVscaleMax().has_value()) 862 return emitOpError("expected 'vscale_min' if and only if 'scalable'"); 863 return success(); 864 } 865 866 ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { 867 if (getDim().has_value()) 868 return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); 869 return ValueBoundsConstraintSet::Variable(getVar()); 870 } 871 872 //===----------------------------------------------------------------------===// 873 // CompareOp 874 //===----------------------------------------------------------------------===// 875 876 ValueBoundsConstraintSet::ComparisonOperator 877 CompareOp::getComparisonOperator() { 878 if (getCmp() == "EQ") 879 return ValueBoundsConstraintSet::ComparisonOperator::EQ; 880 if (getCmp() == "LT") 881 return ValueBoundsConstraintSet::ComparisonOperator::LT; 882 if (getCmp() == "LE") 883 return ValueBoundsConstraintSet::ComparisonOperator::LE; 884 if (getCmp() == "GT") 885 return ValueBoundsConstraintSet::ComparisonOperator::GT; 886 if (getCmp() == "GE") 887 return ValueBoundsConstraintSet::ComparisonOperator::GE; 888 llvm_unreachable("invalid comparison operator"); 889 } 890 891 mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { 892 if (!getLhsMap()) 893 return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); 894 SmallVector<Value> mapOperands( 895 getVarOperands().slice(0, getLhsMap()->getNumInputs())); 896 return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); 897 } 898 899 mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { 900 int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; 901 if (!getRhsMap()) 902 return ValueBoundsConstraintSet::Variable( 903 getVarOperands()[rhsOperandsBegin]); 904 SmallVector<Value> mapOperands( 905 getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); 906 return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); 907 } 908 909 LogicalResult CompareOp::verify() { 910 if (getCompose() && (getLhsMap() || getRhsMap())) 911 return emitOpError( 912 "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); 913 int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; 914 expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; 915 if (getVarOperands().size() != size_t(expectedNumOperands)) 916 return emitOpError("expected ") 917 << expectedNumOperands << " operands, but got " 918 << getVarOperands().size(); 919 return success(); 920 } 921 922 //===----------------------------------------------------------------------===// 923 // TestOpInPlaceSelfFold 924 //===----------------------------------------------------------------------===// 925 926 OpFoldResult TestOpInPlaceSelfFold::fold(FoldAdaptor adaptor) { 927 if (!getFolded()) { 928 // The folder adds the "folded" if not present. 929 setFolded(true); 930 return getResult(); 931 } 932 return {}; 933 } 934 935 //===----------------------------------------------------------------------===// 936 // TestOpFoldWithFoldAdaptor 937 //===----------------------------------------------------------------------===// 938 939 OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { 940 int64_t sum = 0; 941 if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) 942 sum += value.getValue().getSExtValue(); 943 944 for (Attribute attr : adaptor.getVariadic()) 945 if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) 946 sum += 2 * value.getValue().getSExtValue(); 947 948 for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) 949 for (Attribute attr : attrs) 950 if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) 951 sum += 3 * value.getValue().getSExtValue(); 952 953 sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); 954 955 return IntegerAttr::get(getType(), sum); 956 } 957 958 //===----------------------------------------------------------------------===// 959 // OpWithInferTypeAdaptorInterfaceOp 960 //===----------------------------------------------------------------------===// 961 962 LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( 963 MLIRContext *, std::optional<Location> location, 964 OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, 965 SmallVectorImpl<Type> &inferredReturnTypes) { 966 if (adaptor.getX().getType() != adaptor.getY().getType()) { 967 return emitOptionalError(location, "operand type mismatch ", 968 adaptor.getX().getType(), " vs ", 969 adaptor.getY().getType()); 970 } 971 inferredReturnTypes.assign({adaptor.getX().getType()}); 972 return success(); 973 } 974 975 //===----------------------------------------------------------------------===// 976 // OpWithRefineTypeInterfaceOp 977 //===----------------------------------------------------------------------===// 978 979 // TODO: We should be able to only define either inferReturnType or 980 // refineReturnType, currently only refineReturnType can be omitted. 981 LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( 982 MLIRContext *context, std::optional<Location> location, ValueRange operands, 983 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 984 SmallVectorImpl<Type> &returnTypes) { 985 returnTypes.clear(); 986 return OpWithRefineTypeInterfaceOp::refineReturnTypes( 987 context, location, operands, attributes, properties, regions, 988 returnTypes); 989 } 990 991 LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( 992 MLIRContext *, std::optional<Location> location, ValueRange operands, 993 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 994 SmallVectorImpl<Type> &returnTypes) { 995 if (operands[0].getType() != operands[1].getType()) { 996 return emitOptionalError(location, "operand type mismatch ", 997 operands[0].getType(), " vs ", 998 operands[1].getType()); 999 } 1000 // TODO: Add helper to make this more concise to write. 1001 if (returnTypes.empty()) 1002 returnTypes.resize(1, nullptr); 1003 if (returnTypes[0] && returnTypes[0] != operands[0].getType()) 1004 return emitOptionalError(location, 1005 "required first operand and result to match"); 1006 returnTypes[0] = operands[0].getType(); 1007 return success(); 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // OpWithShapedTypeInferTypeAdaptorInterfaceOp 1012 //===----------------------------------------------------------------------===// 1013 1014 LogicalResult 1015 OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( 1016 MLIRContext *context, std::optional<Location> location, 1017 OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, 1018 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { 1019 // Create return type consisting of the last element of the first operand. 1020 auto operandType = adaptor.getOperand1().getType(); 1021 auto sval = dyn_cast<ShapedType>(operandType); 1022 if (!sval) 1023 return emitOptionalError(location, "only shaped type operands allowed"); 1024 int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; 1025 auto type = IntegerType::get(context, 17); 1026 1027 Attribute encoding; 1028 if (auto rankedTy = dyn_cast<RankedTensorType>(sval)) 1029 encoding = rankedTy.getEncoding(); 1030 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); 1031 return success(); 1032 } 1033 1034 LogicalResult 1035 OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( 1036 OpBuilder &builder, ValueRange operands, 1037 llvm::SmallVectorImpl<Value> &shapes) { 1038 shapes = SmallVector<Value, 1>{ 1039 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)}; 1040 return success(); 1041 } 1042 1043 //===----------------------------------------------------------------------===// 1044 // TestOpWithPropertiesAndInferredType 1045 //===----------------------------------------------------------------------===// 1046 1047 LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( 1048 MLIRContext *context, std::optional<Location>, ValueRange operands, 1049 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, 1050 SmallVectorImpl<Type> &inferredReturnTypes) { 1051 1052 Adaptor adaptor(operands, attributes, properties, regions); 1053 inferredReturnTypes.push_back(IntegerType::get( 1054 context, adaptor.getLhs() + adaptor.getProperties().rhs)); 1055 return success(); 1056 } 1057 1058 //===----------------------------------------------------------------------===// 1059 // LoopBlockOp 1060 //===----------------------------------------------------------------------===// 1061 1062 void LoopBlockOp::getSuccessorRegions( 1063 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1064 regions.emplace_back(&getBody(), getBody().getArguments()); 1065 if (point.isParent()) 1066 return; 1067 1068 regions.emplace_back((*this)->getResults()); 1069 } 1070 1071 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { 1072 assert(point == getBody()); 1073 return MutableOperandRange(getInitMutable()); 1074 } 1075 1076 //===----------------------------------------------------------------------===// 1077 // LoopBlockTerminatorOp 1078 //===----------------------------------------------------------------------===// 1079 1080 MutableOperandRange 1081 LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { 1082 if (point.isParent()) 1083 return getExitArgMutable(); 1084 return getNextIterArgMutable(); 1085 } 1086 1087 //===----------------------------------------------------------------------===// 1088 // SwitchWithNoBreakOp 1089 //===----------------------------------------------------------------------===// 1090 1091 void TestNoTerminatorOp::getSuccessorRegions( 1092 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {} 1093 1094 //===----------------------------------------------------------------------===// 1095 // Test InferIntRangeInterface 1096 //===----------------------------------------------------------------------===// 1097 1098 OpFoldResult ManualCppOpWithFold::fold(ArrayRef<Attribute> attributes) { 1099 // Just a simple fold for testing purposes that reads an operands constant 1100 // value and returns it. 1101 if (!attributes.empty()) 1102 return attributes.front(); 1103 return nullptr; 1104 } 1105 1106 //===----------------------------------------------------------------------===// 1107 // Tensor/Buffer Ops 1108 //===----------------------------------------------------------------------===// 1109 1110 void ReadBufferOp::getEffects( 1111 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> 1112 &effects) { 1113 // The buffer operand is read. 1114 effects.emplace_back(MemoryEffects::Read::get(), &getBufferMutable(), 1115 SideEffects::DefaultResource::get()); 1116 // The buffer contents are dumped. 1117 effects.emplace_back(MemoryEffects::Write::get(), 1118 SideEffects::DefaultResource::get()); 1119 } 1120 1121 //===----------------------------------------------------------------------===// 1122 // Test Dataflow 1123 //===----------------------------------------------------------------------===// 1124 1125 //===----------------------------------------------------------------------===// 1126 // TestCallAndStoreOp 1127 1128 CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { 1129 return getCallee(); 1130 } 1131 1132 void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { 1133 setCalleeAttr(cast<SymbolRefAttr>(callee)); 1134 } 1135 1136 Operation::operand_range TestCallAndStoreOp::getArgOperands() { 1137 return getCalleeOperands(); 1138 } 1139 1140 MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { 1141 return getCalleeOperandsMutable(); 1142 } 1143 1144 //===----------------------------------------------------------------------===// 1145 // TestCallOnDeviceOp 1146 1147 CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { 1148 return getCallee(); 1149 } 1150 1151 void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { 1152 setCalleeAttr(cast<SymbolRefAttr>(callee)); 1153 } 1154 1155 Operation::operand_range TestCallOnDeviceOp::getArgOperands() { 1156 return getForwardedOperands(); 1157 } 1158 1159 MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { 1160 return getForwardedOperandsMutable(); 1161 } 1162 1163 //===----------------------------------------------------------------------===// 1164 // TestStoreWithARegion 1165 1166 void TestStoreWithARegion::getSuccessorRegions( 1167 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1168 if (point.isParent()) 1169 regions.emplace_back(&getBody(), getBody().front().getArguments()); 1170 else 1171 regions.emplace_back(); 1172 } 1173 1174 //===----------------------------------------------------------------------===// 1175 // TestStoreWithALoopRegion 1176 1177 void TestStoreWithALoopRegion::getSuccessorRegions( 1178 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { 1179 // Both the operation itself and the region may be branching into the body or 1180 // back into the operation itself. It is possible for the operation not to 1181 // enter the body. 1182 regions.emplace_back( 1183 RegionSuccessor(&getBody(), getBody().front().getArguments())); 1184 regions.emplace_back(); 1185 } 1186 1187 //===----------------------------------------------------------------------===// 1188 // TestVersionedOpA 1189 //===----------------------------------------------------------------------===// 1190 1191 LogicalResult 1192 TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, 1193 mlir::OperationState &state) { 1194 auto &prop = state.getOrAddProperties<Properties>(); 1195 if (mlir::failed(reader.readAttribute(prop.dims))) 1196 return mlir::failure(); 1197 1198 // Check if we have a version. If not, assume we are parsing the current 1199 // version. 1200 auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); 1201 if (succeeded(maybeVersion)) { 1202 // If version is less than 2.0, there is no additional attribute to parse. 1203 // We can materialize missing properties post parsing before verification. 1204 const auto *version = 1205 reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1206 if ((version->major_ < 2)) { 1207 return success(); 1208 } 1209 } 1210 1211 if (mlir::failed(reader.readAttribute(prop.modifier))) 1212 return mlir::failure(); 1213 return mlir::success(); 1214 } 1215 1216 void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { 1217 auto &prop = getProperties(); 1218 writer.writeAttribute(prop.dims); 1219 1220 auto maybeVersion = writer.getDialectVersion<test::TestDialect>(); 1221 if (succeeded(maybeVersion)) { 1222 // If version is less than 2.0, there is no additional attribute to write. 1223 const auto *version = 1224 reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1225 if ((version->major_ < 2)) { 1226 llvm::outs() << "downgrading op properties...\n"; 1227 return; 1228 } 1229 } 1230 writer.writeAttribute(prop.modifier); 1231 } 1232 1233 //===----------------------------------------------------------------------===// 1234 // TestOpWithVersionedProperties 1235 //===----------------------------------------------------------------------===// 1236 1237 llvm::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( 1238 mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { 1239 uint64_t value1, value2 = 0; 1240 if (failed(reader.readVarInt(value1))) 1241 return failure(); 1242 1243 // Check if we have a version. If not, assume we are parsing the current 1244 // version. 1245 auto maybeVersion = reader.getDialectVersion<test::TestDialect>(); 1246 bool needToParseAnotherInt = true; 1247 if (succeeded(maybeVersion)) { 1248 // If version is less than 2.0, there is no additional attribute to parse. 1249 // We can materialize missing properties post parsing before verification. 1250 const auto *version = 1251 reinterpret_cast<const TestDialectVersion *>(*maybeVersion); 1252 if ((version->major_ < 2)) 1253 needToParseAnotherInt = false; 1254 } 1255 if (needToParseAnotherInt && failed(reader.readVarInt(value2))) 1256 return failure(); 1257 1258 prop.value1 = value1; 1259 prop.value2 = value2; 1260 return success(); 1261 } 1262 1263 void TestOpWithVersionedProperties::writeToMlirBytecode( 1264 mlir::DialectBytecodeWriter &writer, 1265 const test::VersionedProperties &prop) { 1266 writer.writeVarInt(prop.value1); 1267 writer.writeVarInt(prop.value2); 1268 } 1269 1270 //===----------------------------------------------------------------------===// 1271 // TestMultiSlotAlloca 1272 //===----------------------------------------------------------------------===// 1273 1274 llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() { 1275 SmallVector<MemorySlot> slots; 1276 for (Value result : getResults()) { 1277 slots.push_back(MemorySlot{ 1278 result, cast<MemRefType>(result.getType()).getElementType()}); 1279 } 1280 return slots; 1281 } 1282 1283 Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot, 1284 OpBuilder &builder) { 1285 return builder.create<TestOpConstant>(getLoc(), slot.elemType, 1286 builder.getI32IntegerAttr(42)); 1287 } 1288 1289 void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot, 1290 BlockArgument argument, 1291 OpBuilder &builder) { 1292 // Not relevant for testing. 1293 } 1294 1295 /// Creates a new TestMultiSlotAlloca operation, just without the `slot`. 1296 static std::optional<TestMultiSlotAlloca> 1297 createNewMultiAllocaWithoutSlot(const MemorySlot &slot, OpBuilder &builder, 1298 TestMultiSlotAlloca oldOp) { 1299 1300 if (oldOp.getNumResults() == 1) { 1301 oldOp.erase(); 1302 return std::nullopt; 1303 } 1304 1305 SmallVector<Type> newTypes; 1306 SmallVector<Value> remainingValues; 1307 1308 for (Value oldResult : oldOp.getResults()) { 1309 if (oldResult == slot.ptr) 1310 continue; 1311 remainingValues.push_back(oldResult); 1312 newTypes.push_back(oldResult.getType()); 1313 } 1314 1315 OpBuilder::InsertionGuard guard(builder); 1316 builder.setInsertionPoint(oldOp); 1317 auto replacement = 1318 builder.create<TestMultiSlotAlloca>(oldOp->getLoc(), newTypes); 1319 for (auto [oldResult, newResult] : 1320 llvm::zip_equal(remainingValues, replacement.getResults())) 1321 oldResult.replaceAllUsesWith(newResult); 1322 1323 oldOp.erase(); 1324 return replacement; 1325 } 1326 1327 std::optional<PromotableAllocationOpInterface> 1328 TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot, 1329 Value defaultValue, 1330 OpBuilder &builder) { 1331 if (defaultValue && defaultValue.use_empty()) 1332 defaultValue.getDefiningOp()->erase(); 1333 return createNewMultiAllocaWithoutSlot(slot, builder, *this); 1334 } 1335 1336 SmallVector<DestructurableMemorySlot> 1337 TestMultiSlotAlloca::getDestructurableSlots() { 1338 SmallVector<DestructurableMemorySlot> slots; 1339 for (Value result : getResults()) { 1340 auto memrefType = cast<MemRefType>(result.getType()); 1341 auto destructurable = dyn_cast<DestructurableTypeInterface>(memrefType); 1342 if (!destructurable) 1343 continue; 1344 1345 std::optional<DenseMap<Attribute, Type>> destructuredType = 1346 destructurable.getSubelementIndexMap(); 1347 if (!destructuredType) 1348 continue; 1349 slots.emplace_back( 1350 DestructurableMemorySlot{{result, memrefType}, *destructuredType}); 1351 } 1352 return slots; 1353 } 1354 1355 DenseMap<Attribute, MemorySlot> TestMultiSlotAlloca::destructure( 1356 const DestructurableMemorySlot &slot, 1357 const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder, 1358 SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) { 1359 OpBuilder::InsertionGuard guard(builder); 1360 builder.setInsertionPointAfter(*this); 1361 1362 DenseMap<Attribute, MemorySlot> slotMap; 1363 1364 for (Attribute usedIndex : usedIndices) { 1365 Type elemType = slot.subelementTypes.lookup(usedIndex); 1366 MemRefType elemPtr = MemRefType::get({}, elemType); 1367 auto subAlloca = builder.create<TestMultiSlotAlloca>(getLoc(), elemPtr); 1368 newAllocators.push_back(subAlloca); 1369 slotMap.try_emplace<MemorySlot>(usedIndex, 1370 {subAlloca.getResult(0), elemType}); 1371 } 1372 1373 return slotMap; 1374 } 1375 1376 std::optional<DestructurableAllocationOpInterface> 1377 TestMultiSlotAlloca::handleDestructuringComplete( 1378 const DestructurableMemorySlot &slot, OpBuilder &builder) { 1379 return createNewMultiAllocaWithoutSlot(slot, builder, *this); 1380 } 1381