1 //===- TestTransformDialectExtension.cpp ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines an extension of the MLIR Transform dialect for testing 10 // purposes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "TestTransformDialectExtension.h" 15 #include "TestTransformStateExtension.h" 16 #include "mlir/Dialect/PDL/IR/PDL.h" 17 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 18 #include "mlir/Dialect/Transform/IR/TransformOps.h" 19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 20 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" 21 #include "mlir/IR/OpImplementation.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "llvm/ADT/STLExtras.h" 24 #include "llvm/ADT/TypeSwitch.h" 25 #include "llvm/Support/Compiler.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 using namespace mlir; 29 30 namespace { 31 /// Simple transform op defined outside of the dialect. Just emits a remark when 32 /// applied. This op is defined in C++ to test that C++ definitions also work 33 /// for op injection into the Transform dialect. 34 class TestTransformOp 35 : public Op<TestTransformOp, transform::TransformOpInterface::Trait, 36 MemoryEffectOpInterface::Trait> { 37 public: 38 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformOp) 39 40 using Op::Op; 41 42 static ArrayRef<StringRef> getAttributeNames() { return {}; } 43 44 static constexpr llvm::StringLiteral getOperationName() { 45 return llvm::StringLiteral("transform.test_transform_op"); 46 } 47 48 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, 49 transform::TransformResults &results, 50 transform::TransformState &state) { 51 InFlightDiagnostic remark = emitRemark() << "applying transformation"; 52 if (Attribute message = getMessage()) 53 remark << " " << message; 54 55 return DiagnosedSilenceableFailure::success(); 56 } 57 58 Attribute getMessage() { 59 return getOperation()->getDiscardableAttr("message"); 60 } 61 62 static ParseResult parse(OpAsmParser &parser, OperationState &state) { 63 StringAttr message; 64 OptionalParseResult result = parser.parseOptionalAttribute(message); 65 if (!result.has_value()) 66 return success(); 67 68 if (result.value().succeeded()) 69 state.addAttribute("message", message); 70 return result.value(); 71 } 72 73 void print(OpAsmPrinter &printer) { 74 if (getMessage()) 75 printer << " " << getMessage(); 76 } 77 78 // No side effects. 79 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 80 }; 81 82 /// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait 83 /// in cases where it is attached to ops that do not comply with the trait 84 /// requirements. This op cannot be defined in ODS because ODS generates strict 85 /// verifiers that overalp with those in the trait and run earlier. 86 class TestTransformUnrestrictedOpNoInterface 87 : public Op<TestTransformUnrestrictedOpNoInterface, 88 transform::PossibleTopLevelTransformOpTrait, 89 transform::TransformOpInterface::Trait, 90 MemoryEffectOpInterface::Trait> { 91 public: 92 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 93 TestTransformUnrestrictedOpNoInterface) 94 95 using Op::Op; 96 97 static ArrayRef<StringRef> getAttributeNames() { return {}; } 98 99 static constexpr llvm::StringLiteral getOperationName() { 100 return llvm::StringLiteral( 101 "transform.test_transform_unrestricted_op_no_interface"); 102 } 103 104 DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, 105 transform::TransformResults &results, 106 transform::TransformState &state) { 107 return DiagnosedSilenceableFailure::success(); 108 } 109 110 // No side effects. 111 void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 112 }; 113 } // namespace 114 115 DiagnosedSilenceableFailure 116 mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( 117 transform::TransformRewriter &rewriter, 118 transform::TransformResults &results, transform::TransformState &state) { 119 if (getOperation()->getNumOperands() != 0) { 120 results.set(cast<OpResult>(getResult()), 121 {getOperation()->getOperand(0).getDefiningOp()}); 122 } else { 123 results.set(cast<OpResult>(getResult()), {getOperation()}); 124 } 125 return DiagnosedSilenceableFailure::success(); 126 } 127 128 void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( 129 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 130 if (getOperand()) 131 transform::onlyReadsHandle(getOperandMutable(), effects); 132 transform::producesHandle(getOperation()->getOpResults(), effects); 133 } 134 135 DiagnosedSilenceableFailure 136 mlir::test::TestProduceValueHandleToSelfOperand::apply( 137 transform::TransformRewriter &rewriter, 138 transform::TransformResults &results, transform::TransformState &state) { 139 results.setValues(llvm::cast<OpResult>(getOut()), {getIn()}); 140 return DiagnosedSilenceableFailure::success(); 141 } 142 143 void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( 144 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 145 transform::onlyReadsHandle(getInMutable(), effects); 146 transform::producesHandle(getOperation()->getOpResults(), effects); 147 transform::onlyReadsPayload(effects); 148 } 149 150 DiagnosedSilenceableFailure 151 mlir::test::TestProduceValueHandleToResult::applyToOne( 152 transform::TransformRewriter &rewriter, Operation *target, 153 transform::ApplyToEachResultList &results, 154 transform::TransformState &state) { 155 if (target->getNumResults() <= getNumber()) 156 return emitSilenceableError() << "payload has no result #" << getNumber(); 157 results.push_back(target->getResult(getNumber())); 158 return DiagnosedSilenceableFailure::success(); 159 } 160 161 void mlir::test::TestProduceValueHandleToResult::getEffects( 162 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 163 transform::onlyReadsHandle(getInMutable(), effects); 164 transform::producesHandle(getOperation()->getOpResults(), effects); 165 transform::onlyReadsPayload(effects); 166 } 167 168 DiagnosedSilenceableFailure 169 mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( 170 transform::TransformRewriter &rewriter, Operation *target, 171 transform::ApplyToEachResultList &results, 172 transform::TransformState &state) { 173 if (!target->getBlock()) 174 return emitSilenceableError() << "payload has no parent block"; 175 if (target->getBlock()->getNumArguments() <= getNumber()) 176 return emitSilenceableError() 177 << "parent of the payload has no argument #" << getNumber(); 178 results.push_back(target->getBlock()->getArgument(getNumber())); 179 return DiagnosedSilenceableFailure::success(); 180 } 181 182 void mlir::test::TestProduceValueHandleToArgumentOfParentBlock::getEffects( 183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 184 transform::onlyReadsHandle(getInMutable(), effects); 185 transform::producesHandle(getOperation()->getOpResults(), effects); 186 transform::onlyReadsPayload(effects); 187 } 188 189 bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { 190 return getAllowRepeatedHandles(); 191 } 192 193 DiagnosedSilenceableFailure 194 mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, 195 transform::TransformResults &results, 196 transform::TransformState &state) { 197 return DiagnosedSilenceableFailure::success(); 198 } 199 200 void mlir::test::TestConsumeOperand::getEffects( 201 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 202 transform::consumesHandle(getOperation()->getOpOperands(), effects); 203 if (getSecondOperand()) 204 transform::consumesHandle(getSecondOperandMutable(), effects); 205 transform::modifiesPayload(effects); 206 } 207 208 DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( 209 transform::TransformRewriter &rewriter, 210 transform::TransformResults &results, transform::TransformState &state) { 211 auto payload = state.getPayloadOps(getOperand()); 212 assert(llvm::hasSingleElement(payload) && "expected a single target op"); 213 if ((*payload.begin())->getName().getStringRef() != getOpKind()) { 214 return emitSilenceableError() 215 << "op expected the operand to be associated a payload op of kind " 216 << getOpKind() << " got " 217 << (*payload.begin())->getName().getStringRef(); 218 } 219 220 emitRemark() << "succeeded"; 221 return DiagnosedSilenceableFailure::success(); 222 } 223 224 void mlir::test::TestConsumeOperandOfOpKindOrFail::getEffects( 225 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 226 transform::consumesHandle(getOperation()->getOpOperands(), effects); 227 transform::modifiesPayload(effects); 228 } 229 230 DiagnosedSilenceableFailure 231 mlir::test::TestSucceedIfOperandOfOpKind::matchOperation( 232 Operation *op, transform::TransformResults &results, 233 transform::TransformState &state) { 234 if (op->getName().getStringRef() != getOpKind()) { 235 return emitSilenceableError() 236 << "op expected the operand to be associated with a payload op of " 237 "kind " 238 << getOpKind() << " got " << op->getName().getStringRef(); 239 } 240 return DiagnosedSilenceableFailure::success(); 241 } 242 243 void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( 244 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 245 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 246 transform::onlyReadsPayload(effects); 247 } 248 249 DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( 250 transform::TransformRewriter &rewriter, 251 transform::TransformResults &results, transform::TransformState &state) { 252 state.addExtension<TestTransformStateExtension>(getMessageAttr()); 253 return DiagnosedSilenceableFailure::success(); 254 } 255 256 DiagnosedSilenceableFailure 257 mlir::test::TestCheckIfTestExtensionPresentOp::apply( 258 transform::TransformRewriter &rewriter, 259 transform::TransformResults &results, transform::TransformState &state) { 260 auto *extension = state.getExtension<TestTransformStateExtension>(); 261 if (!extension) { 262 emitRemark() << "extension absent"; 263 return DiagnosedSilenceableFailure::success(); 264 } 265 266 InFlightDiagnostic diag = emitRemark() 267 << "extension present, " << extension->getMessage(); 268 for (Operation *payload : state.getPayloadOps(getOperand())) { 269 diag.attachNote(payload->getLoc()) << "associated payload op"; 270 #ifndef NDEBUG 271 SmallVector<Value> handles; 272 assert(succeeded(state.getHandlesForPayloadOp(payload, handles))); 273 assert(llvm::is_contained(handles, getOperand()) && 274 "inconsistent mapping between transform IR handles and payload IR " 275 "operations"); 276 #endif // NDEBUG 277 } 278 279 return DiagnosedSilenceableFailure::success(); 280 } 281 282 void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( 283 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 284 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 285 transform::onlyReadsPayload(effects); 286 } 287 288 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( 289 transform::TransformRewriter &rewriter, 290 transform::TransformResults &results, transform::TransformState &state) { 291 auto *extension = state.getExtension<TestTransformStateExtension>(); 292 if (!extension) 293 return emitDefiniteFailure("TestTransformStateExtension missing"); 294 295 if (failed(extension->updateMapping( 296 *state.getPayloadOps(getOperand()).begin(), getOperation()))) 297 return DiagnosedSilenceableFailure::definiteFailure(); 298 if (getNumResults() > 0) 299 results.set(cast<OpResult>(getResult(0)), {getOperation()}); 300 return DiagnosedSilenceableFailure::success(); 301 } 302 303 void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( 304 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 305 transform::onlyReadsHandle(getOperation()->getOpOperands(), effects); 306 transform::producesHandle(getOperation()->getOpResults(), effects); 307 transform::onlyReadsPayload(effects); 308 } 309 310 DiagnosedSilenceableFailure mlir::test::TestRemoveTestExtensionOp::apply( 311 transform::TransformRewriter &rewriter, 312 transform::TransformResults &results, transform::TransformState &state) { 313 state.removeExtension<TestTransformStateExtension>(); 314 return DiagnosedSilenceableFailure::success(); 315 } 316 317 DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( 318 transform::TransformRewriter &rewriter, 319 transform::TransformResults &results, transform::TransformState &state) { 320 auto payloadOps = state.getPayloadOps(getTarget()); 321 auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); 322 results.set(llvm::cast<OpResult>(getResult()), reversedOps); 323 return DiagnosedSilenceableFailure::success(); 324 } 325 326 DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( 327 transform::TransformRewriter &rewriter, 328 transform::TransformResults &results, transform::TransformState &state) { 329 return DiagnosedSilenceableFailure::success(); 330 } 331 332 void mlir::test::TestTransformOpWithRegions::getEffects( 333 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 334 335 DiagnosedSilenceableFailure 336 mlir::test::TestBranchingTransformOpTerminator::apply( 337 transform::TransformRewriter &rewriter, 338 transform::TransformResults &results, transform::TransformState &state) { 339 return DiagnosedSilenceableFailure::success(); 340 } 341 342 void mlir::test::TestBranchingTransformOpTerminator::getEffects( 343 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {} 344 345 DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( 346 transform::TransformRewriter &rewriter, 347 transform::TransformResults &results, transform::TransformState &state) { 348 emitRemark() << getRemark(); 349 for (Operation *op : state.getPayloadOps(getTarget())) 350 rewriter.eraseOp(op); 351 352 if (getFailAfterErase()) 353 return emitSilenceableError() << "silenceable error"; 354 return DiagnosedSilenceableFailure::success(); 355 } 356 357 void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( 358 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 359 transform::consumesHandle(getTargetMutable(), effects); 360 transform::modifiesPayload(effects); 361 } 362 363 DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( 364 transform::TransformRewriter &rewriter, Operation *target, 365 transform::ApplyToEachResultList &results, 366 transform::TransformState &state) { 367 OperationState opState(target->getLoc(), "foo"); 368 results.push_back(OpBuilder(target).create(opState)); 369 return DiagnosedSilenceableFailure::success(); 370 } 371 372 DiagnosedSilenceableFailure 373 mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( 374 transform::TransformRewriter &rewriter, Operation *target, 375 transform::ApplyToEachResultList &results, 376 transform::TransformState &state) { 377 static int count = 0; 378 if (count++ == 0) { 379 OperationState opState(target->getLoc(), "foo"); 380 results.push_back(OpBuilder(target).create(opState)); 381 } 382 return DiagnosedSilenceableFailure::success(); 383 } 384 385 DiagnosedSilenceableFailure 386 mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( 387 transform::TransformRewriter &rewriter, Operation *target, 388 transform::ApplyToEachResultList &results, 389 transform::TransformState &state) { 390 OperationState opState(target->getLoc(), "foo"); 391 results.push_back(OpBuilder(target).create(opState)); 392 results.push_back(OpBuilder(target).create(opState)); 393 return DiagnosedSilenceableFailure::success(); 394 } 395 396 DiagnosedSilenceableFailure 397 mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( 398 transform::TransformRewriter &rewriter, Operation *target, 399 transform::ApplyToEachResultList &results, 400 transform::TransformState &state) { 401 OperationState opState(target->getLoc(), "foo"); 402 results.push_back(nullptr); 403 results.push_back(OpBuilder(target).create(opState)); 404 return DiagnosedSilenceableFailure::success(); 405 } 406 407 DiagnosedSilenceableFailure 408 mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( 409 transform::TransformRewriter &rewriter, Operation *target, 410 transform::ApplyToEachResultList &results, 411 transform::TransformState &state) { 412 if (target->hasAttr("target_me")) 413 return DiagnosedSilenceableFailure::success(); 414 return emitDefaultSilenceableFailure(target); 415 } 416 417 DiagnosedSilenceableFailure 418 mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, 419 transform::TransformResults &results, 420 transform::TransformState &state) { 421 results.set(llvm::cast<OpResult>(getCopy()), 422 state.getPayloadOps(getHandle())); 423 return DiagnosedSilenceableFailure::success(); 424 } 425 426 void mlir::test::TestCopyPayloadOp::getEffects( 427 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 428 transform::onlyReadsHandle(getHandleMutable(), effects); 429 transform::producesHandle(getOperation()->getOpResults(), effects); 430 transform::onlyReadsPayload(effects); 431 } 432 433 DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload( 434 Location loc, ArrayRef<Operation *> payload) const { 435 if (payload.empty()) 436 return DiagnosedSilenceableFailure::success(); 437 438 for (Operation *op : payload) { 439 if (op->getName().getDialectNamespace() != "test") { 440 return emitSilenceableError(loc) << "expected the payload operation to " 441 "belong to the 'test' dialect"; 442 } 443 } 444 445 return DiagnosedSilenceableFailure::success(); 446 } 447 448 DiagnosedSilenceableFailure mlir::transform::TestDialectParamType::checkPayload( 449 Location loc, ArrayRef<Attribute> payload) const { 450 for (Attribute attr : payload) { 451 auto integerAttr = llvm::dyn_cast<IntegerAttr>(attr); 452 if (integerAttr && integerAttr.getType().isSignlessInteger(32)) 453 continue; 454 return emitSilenceableError(loc) 455 << "expected the parameter to be a i32 integer attribute"; 456 } 457 458 return DiagnosedSilenceableFailure::success(); 459 } 460 461 void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( 462 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 463 transform::onlyReadsHandle(getTargetMutable(), effects); 464 } 465 466 DiagnosedSilenceableFailure 467 mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( 468 transform::TransformRewriter &rewriter, 469 transform::TransformResults &results, transform::TransformState &state) { 470 int64_t count = 0; 471 for (Operation *op : state.getPayloadOps(getTarget())) { 472 op->walk([&](Operation *nested) { 473 SmallVector<Value> handles; 474 (void)state.getHandlesForPayloadOp(nested, handles); 475 count += handles.size(); 476 }); 477 } 478 emitRemark() << count << " handles nested under"; 479 return DiagnosedSilenceableFailure::success(); 480 } 481 482 DiagnosedSilenceableFailure 483 mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, 484 transform::TransformResults &results, 485 transform::TransformState &state) { 486 SmallVector<uint32_t> values(/*Size=*/1, /*Value=*/0); 487 if (Value param = getParam()) { 488 values = llvm::to_vector( 489 llvm::map_range(state.getParams(param), [](Attribute attr) -> uint32_t { 490 return llvm::cast<IntegerAttr>(attr).getValue().getLimitedValue( 491 UINT32_MAX); 492 })); 493 } 494 495 Builder builder(getContext()); 496 SmallVector<Attribute> result = llvm::to_vector( 497 llvm::map_range(values, [this, &builder](uint32_t value) -> Attribute { 498 return builder.getI32IntegerAttr(value + getAddendum()); 499 })); 500 results.setParams(llvm::cast<OpResult>(getResult()), result); 501 return DiagnosedSilenceableFailure::success(); 502 } 503 504 DiagnosedSilenceableFailure 505 mlir::test::TestProduceParamWithNumberOfTestOps::apply( 506 transform::TransformRewriter &rewriter, 507 transform::TransformResults &results, transform::TransformState &state) { 508 Builder builder(getContext()); 509 SmallVector<Attribute> result = llvm::to_vector( 510 llvm::map_range(state.getPayloadOps(getHandle()), 511 [&builder](Operation *payload) -> Attribute { 512 int32_t count = 0; 513 payload->walk([&count](Operation *op) { 514 if (op->getName().getDialectNamespace() == "test") 515 ++count; 516 }); 517 return builder.getI32IntegerAttr(count); 518 })); 519 results.setParams(llvm::cast<OpResult>(getResult()), result); 520 return DiagnosedSilenceableFailure::success(); 521 } 522 523 DiagnosedSilenceableFailure 524 mlir::test::TestProduceParamOp::apply(transform::TransformRewriter &rewriter, 525 transform::TransformResults &results, 526 transform::TransformState &state) { 527 results.setParams(llvm::cast<OpResult>(getResult()), getAttr()); 528 return DiagnosedSilenceableFailure::success(); 529 } 530 531 void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( 532 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 533 transform::onlyReadsHandle(getInMutable(), effects); 534 transform::producesHandle(getOperation()->getOpResults(), effects); 535 } 536 537 DiagnosedSilenceableFailure 538 mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( 539 transform::TransformRewriter &rewriter, Operation *target, 540 ::transform::ApplyToEachResultList &results, 541 ::transform::TransformState &state) { 542 Builder builder(getContext()); 543 if (getFirstResultIsParam()) { 544 results.push_back(builder.getI64IntegerAttr(0)); 545 } else if (getFirstResultIsNull()) { 546 results.push_back(nullptr); 547 } else { 548 results.push_back(*state.getPayloadOps(getIn()).begin()); 549 } 550 551 if (getSecondResultIsHandle()) { 552 results.push_back(*state.getPayloadOps(getIn()).begin()); 553 } else { 554 results.push_back(builder.getI64IntegerAttr(42)); 555 } 556 557 return DiagnosedSilenceableFailure::success(); 558 } 559 560 void mlir::test::TestProduceNullPayloadOp::getEffects( 561 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 562 transform::producesHandle(getOperation()->getOpResults(), effects); 563 } 564 565 DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( 566 transform::TransformRewriter &rewriter, 567 transform::TransformResults &results, transform::TransformState &state) { 568 SmallVector<Operation *, 1> null({nullptr}); 569 results.set(llvm::cast<OpResult>(getOut()), null); 570 return DiagnosedSilenceableFailure::success(); 571 } 572 573 DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( 574 transform::TransformRewriter &rewriter, 575 transform::TransformResults &results, transform::TransformState &state) { 576 results.set(cast<OpResult>(getOut()), {}); 577 return DiagnosedSilenceableFailure::success(); 578 } 579 580 void mlir::test::TestProduceNullParamOp::getEffects( 581 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 582 transform::producesHandle(getOperation()->getOpResults(), effects); 583 } 584 585 DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( 586 transform::TransformRewriter &rewriter, 587 transform::TransformResults &results, transform::TransformState &state) { 588 results.setParams(llvm::cast<OpResult>(getOut()), Attribute()); 589 return DiagnosedSilenceableFailure::success(); 590 } 591 592 void mlir::test::TestProduceNullValueOp::getEffects( 593 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 594 transform::producesHandle(getOperation()->getOpResults(), effects); 595 } 596 597 DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( 598 transform::TransformRewriter &rewriter, 599 transform::TransformResults &results, transform::TransformState &state) { 600 results.setValues(llvm::cast<OpResult>(getOut()), {Value()}); 601 return DiagnosedSilenceableFailure::success(); 602 } 603 604 void mlir::test::TestRequiredMemoryEffectsOp::getEffects( 605 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 606 if (getHasOperandEffect()) 607 transform::consumesHandle(getInMutable(), effects); 608 609 if (getHasResultEffect()) { 610 transform::producesHandle(getOperation()->getOpResults(), effects); 611 } else { 612 effects.emplace_back(MemoryEffects::Read::get(), 613 llvm::cast<OpResult>(getOut()), 614 transform::TransformMappingResource::get()); 615 } 616 617 if (getModifiesPayload()) 618 transform::modifiesPayload(effects); 619 } 620 621 DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( 622 transform::TransformRewriter &rewriter, 623 transform::TransformResults &results, transform::TransformState &state) { 624 results.set(llvm::cast<OpResult>(getOut()), state.getPayloadOps(getIn())); 625 return DiagnosedSilenceableFailure::success(); 626 } 627 628 void mlir::test::TestTrackedRewriteOp::getEffects( 629 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 630 transform::onlyReadsHandle(getInMutable(), effects); 631 transform::modifiesPayload(effects); 632 } 633 634 void mlir::test::TestDummyPayloadOp::getEffects( 635 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 636 transform::producesHandle(getOperation()->getOpResults(), effects); 637 } 638 639 LogicalResult mlir::test::TestDummyPayloadOp::verify() { 640 if (getFailToVerify()) 641 return emitOpError() << "fail_to_verify is set"; 642 return success(); 643 } 644 645 DiagnosedSilenceableFailure 646 mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, 647 transform::TransformResults &results, 648 transform::TransformState &state) { 649 int64_t numIterations = 0; 650 651 // `getPayloadOps` returns an iterator that skips ops that are erased in the 652 // loop body. Replacement ops are not enumerated. 653 for (Operation *op : state.getPayloadOps(getIn())) { 654 ++numIterations; 655 (void)op; 656 657 // Erase all payload ops. The outer loop should have only one iteration. 658 for (Operation *op : state.getPayloadOps(getIn())) { 659 rewriter.setInsertionPoint(op); 660 if (op->hasAttr("erase_me")) { 661 rewriter.eraseOp(op); 662 continue; 663 } 664 if (!op->hasAttr("replace_me")) { 665 continue; 666 } 667 668 SmallVector<NamedAttribute> attributes; 669 attributes.emplace_back(rewriter.getStringAttr("new_op"), 670 rewriter.getUnitAttr()); 671 OperationState opState(op->getLoc(), op->getName().getIdentifier(), 672 /*operands=*/ValueRange(), 673 /*types=*/op->getResultTypes(), attributes); 674 Operation *newOp = rewriter.create(opState); 675 rewriter.replaceOp(op, newOp->getResults()); 676 } 677 } 678 679 emitRemark() << numIterations << " iterations"; 680 return DiagnosedSilenceableFailure::success(); 681 } 682 683 namespace { 684 // Test pattern to replace an operation with a new op. 685 class ReplaceWithNewOp : public RewritePattern { 686 public: 687 ReplaceWithNewOp(MLIRContext *context) 688 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 689 690 LogicalResult matchAndRewrite(Operation *op, 691 PatternRewriter &rewriter) const override { 692 auto newName = op->getAttrOfType<StringAttr>("replace_with_new_op"); 693 if (!newName) 694 return failure(); 695 Operation *newOp = rewriter.create( 696 op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), 697 op->getOperands(), op->getResultTypes()); 698 rewriter.replaceOp(op, newOp->getResults()); 699 return success(); 700 } 701 }; 702 703 // Test pattern to erase an operation. 704 class EraseOp : public RewritePattern { 705 public: 706 EraseOp(MLIRContext *context) 707 : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 708 LogicalResult matchAndRewrite(Operation *op, 709 PatternRewriter &rewriter) const override { 710 rewriter.eraseOp(op); 711 return success(); 712 } 713 }; 714 } // namespace 715 716 void mlir::test::ApplyTestPatternsOp::populatePatterns( 717 RewritePatternSet &patterns) { 718 patterns.insert<ReplaceWithNewOp, EraseOp>(patterns.getContext()); 719 } 720 721 void mlir::test::TestReEnterRegionOp::getEffects( 722 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 723 transform::consumesHandle(getOperation()->getOpOperands(), effects); 724 transform::modifiesPayload(effects); 725 } 726 727 DiagnosedSilenceableFailure 728 mlir::test::TestReEnterRegionOp::apply(transform::TransformRewriter &rewriter, 729 transform::TransformResults &results, 730 transform::TransformState &state) { 731 732 SmallVector<SmallVector<transform::MappedValue>> mappings; 733 for (BlockArgument arg : getBody().front().getArguments()) { 734 mappings.emplace_back(llvm::to_vector(llvm::map_range( 735 state.getPayloadOps(getOperand(arg.getArgNumber())), 736 [](Operation *op) -> transform::MappedValue { return op; }))); 737 } 738 739 for (int i = 0; i < 4; ++i) { 740 auto scope = state.make_region_scope(getBody()); 741 for (BlockArgument arg : getBody().front().getArguments()) { 742 if (failed(state.mapBlockArgument(arg, mappings[arg.getArgNumber()]))) 743 return DiagnosedSilenceableFailure::definiteFailure(); 744 } 745 for (Operation &op : getBody().front().without_terminator()) { 746 DiagnosedSilenceableFailure diag = 747 state.applyTransform(cast<transform::TransformOpInterface>(op)); 748 if (!diag.succeeded()) 749 return diag; 750 } 751 } 752 return DiagnosedSilenceableFailure::success(); 753 } 754 755 LogicalResult mlir::test::TestReEnterRegionOp::verify() { 756 if (getNumOperands() != getBody().front().getNumArguments()) { 757 return emitOpError() << "expects as many operands as block arguments"; 758 } 759 return success(); 760 } 761 762 DiagnosedSilenceableFailure mlir::test::TestNotifyPayloadOpReplacedOp::apply( 763 transform::TransformRewriter &rewriter, 764 transform::TransformResults &results, transform::TransformState &state) { 765 auto originalOps = state.getPayloadOps(getOriginal()); 766 auto replacementOps = state.getPayloadOps(getReplacement()); 767 if (llvm::range_size(originalOps) != llvm::range_size(replacementOps)) 768 return emitSilenceableError() << "expected same number of original and " 769 "replacement payload operations"; 770 for (const auto &[original, replacement] : 771 llvm::zip(originalOps, replacementOps)) { 772 if (failed( 773 rewriter.notifyPayloadOperationReplaced(original, replacement))) { 774 auto diag = emitSilenceableError() 775 << "unable to replace payload op in transform mapping"; 776 diag.attachNote(original->getLoc()) << "original payload op"; 777 diag.attachNote(replacement->getLoc()) << "replacement payload op"; 778 return diag; 779 } 780 } 781 return DiagnosedSilenceableFailure::success(); 782 } 783 784 void mlir::test::TestNotifyPayloadOpReplacedOp::getEffects( 785 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 786 transform::onlyReadsHandle(getOriginalMutable(), effects); 787 transform::onlyReadsHandle(getReplacementMutable(), effects); 788 } 789 790 DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne( 791 transform::TransformRewriter &rewriter, Operation *target, 792 transform::ApplyToEachResultList &results, 793 transform::TransformState &state) { 794 // Provide some IR that does not verify. 795 rewriter.setInsertionPointToStart(&target->getRegion(0).front()); 796 rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(), 797 ValueRange(), /*failToVerify=*/true); 798 return DiagnosedSilenceableFailure::success(); 799 } 800 801 void mlir::test::TestProduceInvalidIR::getEffects( 802 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 803 transform::onlyReadsHandle(getTargetMutable(), effects); 804 transform::modifiesPayload(effects); 805 } 806 807 DiagnosedSilenceableFailure mlir::test::TestInitializerExtensionOp::apply( 808 transform::TransformRewriter &rewriter, 809 transform::TransformResults &results, transform::TransformState &state) { 810 std::string opName = 811 this->getOperationName().str() + "_" + getTypeAttr().str(); 812 TransformStateInitializerExtension *initExt = 813 state.getExtension<TransformStateInitializerExtension>(); 814 if (!initExt) { 815 emitRemark() << "\nSpecified extension not found, adding a new one!\n"; 816 SmallVector<std::string> opCollection = {opName}; 817 state.addExtension<TransformStateInitializerExtension>(1, opCollection); 818 } else { 819 initExt->setNumOp(initExt->getNumOp() + 1); 820 initExt->pushRegisteredOps(opName); 821 InFlightDiagnostic diag = emitRemark() 822 << "Number of currently registered op: " 823 << initExt->getNumOp() << "\n" 824 << initExt->printMessage() << "\n"; 825 } 826 return DiagnosedSilenceableFailure::success(); 827 } 828 829 namespace { 830 /// Test conversion pattern that replaces ops with the "replace_with_new_op" 831 /// attribute with "test.new_op". 832 class ReplaceWithNewOpConversion : public ConversionPattern { 833 public: 834 ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) 835 : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), 836 /*benefit=*/1, context) {} 837 838 LogicalResult 839 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 840 ConversionPatternRewriter &rewriter) const override { 841 if (!op->hasAttr("replace_with_new_op")) 842 return failure(); 843 SmallVector<Type> newResultTypes; 844 if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), 845 newResultTypes))) 846 return failure(); 847 Operation *newOp = rewriter.create( 848 op->getLoc(), 849 OperationName("test.new_op", op->getContext()).getIdentifier(), 850 operands, newResultTypes); 851 rewriter.replaceOp(op, newOp->getResults()); 852 return success(); 853 } 854 }; 855 } // namespace 856 857 void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( 858 TypeConverter &typeConverter, RewritePatternSet &patterns) { 859 patterns.insert<ReplaceWithNewOpConversion>(typeConverter, 860 patterns.getContext()); 861 } 862 863 namespace { 864 /// Test type converter that converts tensor types to memref types. 865 class TestTypeConverter : public TypeConverter { 866 public: 867 TestTypeConverter() { 868 addConversion([](Type t) { return t; }); 869 addConversion([](RankedTensorType type) -> Type { 870 return MemRefType::get(type.getShape(), type.getElementType()); 871 }); 872 auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, 873 ValueRange inputs, 874 Location loc) -> Value { 875 if (inputs.size() != 1) 876 return Value(); 877 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) 878 .getResult(0); 879 }; 880 addSourceMaterialization(unrealizedCastConverter); 881 addTargetMaterialization(unrealizedCastConverter); 882 } 883 }; 884 } // namespace 885 886 std::unique_ptr<::mlir::TypeConverter> 887 mlir::test::TestTypeConverterOp::getTypeConverter() { 888 return std::make_unique<TestTypeConverter>(); 889 } 890 891 namespace { 892 /// Test extension of the Transform dialect. Registers additional ops and 893 /// declares PDL as dependent dialect since the additional ops are using PDL 894 /// types for operands and results. 895 class TestTransformDialectExtension 896 : public transform::TransformDialectExtension< 897 TestTransformDialectExtension> { 898 public: 899 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformDialectExtension) 900 901 using Base::Base; 902 903 void init() { 904 declareDependentDialect<pdl::PDLDialect>(); 905 registerTransformOps<TestTransformOp, 906 TestTransformUnrestrictedOpNoInterface, 907 #define GET_OP_LIST 908 #include "TestTransformDialectExtension.cpp.inc" 909 >(); 910 registerTypes< 911 #define GET_TYPEDEF_LIST 912 #include "TestTransformDialectExtensionTypes.cpp.inc" 913 >(); 914 915 auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, 916 ArrayRef<PDLValue> pdlValues) { 917 for (const PDLValue &pdlValue : pdlValues) { 918 if (Operation *op = pdlValue.dyn_cast<Operation *>()) { 919 op->emitWarning() << "from PDL constraint"; 920 } 921 } 922 return success(); 923 }; 924 925 addDialectDataInitializer<transform::PDLMatchHooks>( 926 [&](transform::PDLMatchHooks &hooks) { 927 llvm::StringMap<PDLConstraintFunction> constraints; 928 constraints.try_emplace("verbose_constraint", verboseConstraint); 929 hooks.mergeInPDLMatchHooks(std::move(constraints)); 930 }); 931 } 932 }; 933 } // namespace 934 935 // These are automatically generated by ODS but are not used as the Transform 936 // dialect uses a different dispatch mechanism to support dialect extensions. 937 LLVM_ATTRIBUTE_UNUSED static OptionalParseResult 938 generatedTypeParser(AsmParser &parser, StringRef *mnemonic, Type &value); 939 LLVM_ATTRIBUTE_UNUSED static LogicalResult 940 generatedTypePrinter(Type def, AsmPrinter &printer); 941 942 #define GET_TYPEDEF_CLASSES 943 #include "TestTransformDialectExtensionTypes.cpp.inc" 944 945 #define GET_OP_CLASSES 946 #include "TestTransformDialectExtension.cpp.inc" 947 948 void ::test::registerTestTransformDialectExtension(DialectRegistry ®istry) { 949 registry.addExtensions<TestTransformDialectExtension>(); 950 } 951